import warnings
from logging import getLogger
from queue import Queue
from threading import Thread, Semaphore
from typing import *

import numpy as np

from ..typing_ import *
from ..utils import (minibatch_slices_iterator, AutoInitAndCloseable,
                     NOT_SET, GeneratorIterator, to_number_or_numpy)

__all__ = [
    'DataStream', 'UserGeneratorDataStream',
    'ArraysDataStream', 'IntSeqDataStream', 'GeneratorFactoryDataStream',
    'GatherDataStream', 'MapperDataStream', 'ThreadingDataStream',
]


def map_to_tuple(fn: Callable[[Any], TObject], seq: Iterable[Any]):
    return tuple(fn(s) for s in seq)


def to_data_shapes(data_shapes) -> Tuple[ArrayShape, ...]:
    return map_to_tuple(lambda x: map_to_tuple(int, x), data_shapes)


def to_readonly_array(arr: Array) -> Array:
    arr = np.asarray(arr)
    arr.setflags(write=False)
    return arr


def ensure_batch_is_tuple(batch: Union[Array, ArrayTupleOrList]
                          ) -> ArrayTuple:
    if not isinstance(batch, (tuple, list)):
        batch = (batch,)
    else:
        batch = tuple(batch)
    return batch


class DataStream(object):
    """
    Class to construct mini-batch data iterators.

    Constructing Data Streams
    =========================

    All :class:`DataStream` subclasses shipped by `ml_essentials` can be
    constructed via factory methods of this base class.

    To construct a data stream from NumPy arrays, you may:

    >>> x = np.arange(5, dtype=np.int32)
    >>> y = x ** 2
    >>> stream = DataStream.arrays([x, y], batch_size=3)
    >>> for [a, b] in stream:
    ...     print(a, b)
    [0 1 2] [0 1 4]
    [3 4] [ 9 16]

    To construct an integer sequence data stream, you may:

    >>> stream = DataStream.int_seq(start=1, stop=10, step=2, batch_size=3)
    >>> for [a] in stream:
    ...     print(a)
    [1 3 5]
    [7 9]

    To gather multiple data streams into one, you may:

    >>> stream_1 = DataStream.int_seq(5, batch_size=3)
    >>> stream_2 = DataStream.int_seq(-5, step=-1, batch_size=3)
    >>> for [a] in stream_1:
    ...     print(a)
    [0 1 2]
    [3 4]
    >>> for [b] in stream_2:
    ...     print(b)
    [ 0 -1 -2]
    [-3 -4]
    >>> stream = DataStream.gather([stream_1, stream_2])
    >>> for [a, b] in stream:
    ...     print(a, b)
    [0 1 2] [ 0 -1 -2]
    [3 4] [-3 -4]

    To turn an arbitrary mini-batch generator factory function into a data
    stream, you may:

    >>> def data_generator():
    ...     for i in range(2):
    ...         yield np.arange(i * 3, (i + 1) * 3, dtype=np.int32)
    >>> stream = DataStream.generator(data_generator)
    >>> for [a] in stream:
    ...     print(a)
    [0 1 2]
    [3 4 5]

    or you may generate a tuple / list of arrays:

    >>> def data_generator():
    ...     for i in range(2):
    ...         arr = np.arange(i * 3, (i + 1) * 3, dtype=np.int32)
    ...         yield arr, arr ** 2  # or yield [arr, arr ** 2]
    >>> stream = DataStream.generator(data_generator)
    >>> for [a, b] in stream:
    ...     print(a, b)
    [0 1 2] [0 1 4]
    [3 4 5] [ 9 16 25]

    Transforming Data Streams
    =========================

    A :class:`DataStream` instance can be transformed into another data
    stream.

    To select a subset of the arrays within each mini-batch, or re-order
    the arrays, you may:

    >>> x = np.arange(0, 5, dtype=np.int32)
    >>> y = np.arange(5, 10, dtype=np.int32)
    >>> z = np.arange(10, 15, dtype=np.int32)
    >>> # note we shall select [x, z, x]
    >>> stream = DataStream.arrays([x, y, z], batch_size=3).select([0, 2, 0])
    >>> for [a, b, c] in stream:
    ...     print(a, b, c)
    [0 1 2] [10 11 12] [0 1 2]
    [3 4] [13 14] [3 4]

    To transform the arrays within each mini-batch by a mapper function,
    you may:

    >>> def mapper(x, y):
    ...     return x + y
    >>> x = np.arange(0, 5, dtype=np.int32)
    >>> y = np.arange(5, 10, dtype=np.int32)
    >>> stream = DataStream.arrays([x, y], batch_size=3).map(mapper)
    >>> for [a] in stream:
    ...     print(a)
    [5 7 9]
    [11 13]

    or you may return a tuple / list of arrays:

    >>> def mapper(x, y):
    ...     return x + y, x * y  # or return [x + y, x * y]
    >>> x = np.arange(0, 5, dtype=np.int32)
    >>> y = np.arange(5, 10, dtype=np.int32)
    >>> stream = DataStream.arrays([x, y], batch_size=3).map(mapper)
    >>> for [a, b] in stream:
    ...     print(a, b)
    [5 7 9] [ 0  6 14]
    [11 13] [24 36]

    To pre-fetch from a time-consuming data stream in a background thread
    (which is necessary when using a slow mapper), you may:

    >>> stream = DataStream.int_seq(5, batch_size=3)
    >>> with stream.threaded(prefetch=2) as prefetch_stream:
    ...     for [x] in prefetch_stream:
    ...         print(x)
    [0 1 2]
    [3 4]
    """

    def __init__(self,
                 batch_size: Optional[int] = None,
                 array_count: Optional[int] = None,
                 data_shapes: Optional[Tuple[ArrayShape, ...]] = None,
                 data_length: Optional[int] = None,
                 random_state: Optional[np.random.RandomState] = None):
        """
        Construct a :class:`DataStream`.

        Args:
            batch_size: The number of data within each mini-batch.
            array_count: The number of arrays within each mini-batch.
            data_shapes: The data shapes (excluding the batch axis).
            data_length: The total number of data.
            random_state: The NumPy random state instance.

        Raises:
            ValueError: If `len(data_shapes) != array_count`.

                >>> stream = DataStream(data_shapes=((), (3, 5)), array_count=3)
                Traceback (most recent call last):
                    ...
                ValueError: len(data_shapes) != array_count: data_shapes ((), (3, 5)) vs array_count 3
        """
        if batch_size is not None:
            batch_size = int(batch_size)
        if array_count is not None:
            array_count = int(array_count)
        if data_shapes is not None:
            data_shapes = to_data_shapes(data_shapes)
            if array_count is None:
                array_count = len(data_shapes)
            elif array_count != len(data_shapes):
                raise ValueError(f'len(data_shapes) != array_count: '
                                 f'data_shapes {data_shapes} vs '
                                 f'array_count {array_count}')
        if data_length is not None:
            data_length = int(data_length)
        if random_state is not None and \
                not isinstance(random_state, np.random.RandomState):
            raise TypeError(f'`random_state` is not np.random.RandomState: '
                            f'{random_state!r}')

        if data_length is not None and batch_size is not None:
            batch_count = int((data_length + batch_size - 1) // batch_size)
        else:
            batch_count = None

        self._batch_size = batch_size
        self._batch_count = batch_count
        self._array_count = array_count
        self._data_shapes = data_shapes
        self._data_length = data_length
        self._random_state = random_state
        self._active_iterator = None
        self._auto_close_iterator_warning_printed = False

    def __iter__(self) -> GeneratorIterator[ArrayTuple]:
        """
        Iterate through the mini-batches.

        Note that if a previous iterator is not closed before obtaining a new
        one, the previous iterator will be closed automatically, and a warning
        will be printed to the console (only once).
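
        A minimal sketch of manual iteration (if you break out of the loop
        early, remember to close the iterator):

        >>> stream = DataStream.int_seq(5, batch_size=3)
        >>> it = iter(stream)
        >>> for batch in it:
        ...     print(batch[0])
        ...     break
        [0 1 2]
        >>> it.close()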
        """
        if self._active_iterator is not None:
            self._active_iterator.close()
            self._active_iterator = None
            if not self._auto_close_iterator_warning_printed:
                warnings.warn(
                    f'Another iterator of the DataStream {self!r} is still '
                    f'active, will close it automatically. If you did not '
                    f'exhaust the iterator, remember to call `close()` on it.',
                    UserWarning,
                )
                self._auto_close_iterator_warning_printed = True

        def make_generator():
            g = self._minibatch_iterator()
            try:
                yield from g
            finally:
                self._active_iterator = None

        self._active_iterator = GeneratorIterator(make_generator())
        return self._active_iterator

    def __len__(self):
        """
        Get the total number of data.

        If a data stream reports this number (i.e., it is not None), then it
        equals the sum of array lengths from all mini-batches in one epoch.

        >>> stream = DataStream.int_seq(5, batch_size=3)
        >>> len(stream)
        5
        >>> stream = DataStream.int_seq(5, batch_size=3, skip_incomplete=True)
        >>> len(stream)
        3

        Raises:
            RuntimeError: If the data stream cannot report this number,
                i.e., `data_length` is None.

                >>> def g():
                ...     yield np.arange(3)
                >>> stream = DataStream.generator(g)
                >>> stream.data_length is None
                True
                >>> len(stream)
                Traceback (most recent call last):
                    ...
                RuntimeError: stream data length is not available
        """
        ret = self.data_length
        if ret is None:
            raise RuntimeError('stream data length is not available')
        return ret

    @property
    def batch_size(self) -> Optional[int]:
        """
        Get the batch size of this data stream.

        If a data stream reports this number (i.e., it is not None), then the
        actual length of each mini-batch is guaranteed to be NO MORE THAN
        this number.

        >>> x = np.random.normal(size=[5, 4])
        >>> stream = DataStream.arrays([x], batch_size=3)
        >>> stream.batch_size
        3
        """
        return self._batch_size

    @property
    def array_count(self) -> Optional[int]:
        """
        Get the count of arrays within each mini-batch.

        >>> x = np.random.normal(size=[5, 4])
        >>> y = np.random.normal(size=[5, 3, 2])
        >>> stream = DataStream.arrays([x, y], batch_size=3)
        >>> stream.array_count
        2
        """
        return self._array_count

    @property
    def data_shapes(self) -> Optional[Tuple[ArrayShape, ...]]:
        """
        Get the data shapes.

        Data shapes are the shapes of the mini-batch arrays without the
        batch axis.

        >>> x = np.random.normal(size=[5, 4])
        >>> y = np.random.normal(size=[5, 3, 2])
        >>> stream = DataStream.arrays([x, y], batch_size=3)
        >>> stream.data_shapes
        ((4,), (3, 2))
        """
        return self._data_shapes

    @property
    def data_length(self) -> Optional[int]:
        """
        Get the total number of data.

        If a data stream reports this number (i.e., it is not None), then it
        equals the sum of array lengths from all mini-batches in one epoch.

        >>> stream = DataStream.int_seq(5, batch_size=3)
        >>> stream.data_length
        5
        >>> stream = DataStream.int_seq(5, batch_size=3, skip_incomplete=True)
        >>> stream.data_length
        3
        """
        return self._data_length

    @property
    def batch_count(self) -> Optional[int]:
        """
        Get the total number of batches in an epoch.

        >>> stream = DataStream.int_seq(5, batch_size=3)
        >>> stream.batch_count
        2
        >>> stream = DataStream.int_seq(5, batch_size=3, skip_incomplete=True)
        >>> stream.batch_count
        1
        """
        return self._batch_count

    @property
    def random_state(self) -> Optional[np.random.RandomState]:
        """Get the NumPy random state associated with this data stream."""
        return self._random_state

    def copy(self, **kwargs):
        """
        Get a copy of this data stream.

        You may override some of the construction arguments by specifying
        them as named arguments via `kwargs`.  However, some arguments may
        not be overridable (depending on the implementation of subclasses).

        >>> x = np.arange(5, dtype=np.int32)
        >>> stream = DataStream.arrays([x], batch_size=3)
        >>> for [a] in stream:
        ...     print(a)
        [0 1 2]
        [3 4]
        >>> stream2 = stream.copy(batch_size=4)
        >>> isinstance(stream2, ArraysDataStream)
        True
        >>> for [a] in stream2:
        ...     print(a)
        [0 1 2 3]
        [4]

        Args:
            \\**kwargs: The overridden construction arguments.

        Returns:
            The copied data stream.
        """
        raise NotImplementedError()

    def _copy_helper(self, attrs: Iterable[str], **kwargs):
        for attr in attrs:
            kwargs.setdefault(attr, getattr(self, attr))
        return self.__class__(**kwargs)

    def _minibatch_iterator(self) -> Generator[ArrayTuple, None, None]:
        raise NotImplementedError()

    def get_arrays(self, max_batch: Optional[int] = None
                   ) -> Tuple[np.ndarray, ...]:
        """
        Collect mini-batches into NumPy arrays.
        >>> x = np.arange(0, 5, dtype=np.int32)
        >>> stream = DataStream.arrays([x], batch_size=3).map(lambda t: t ** 2)
        >>> arrays = stream.get_arrays()
        >>> len(arrays)
        1
        >>> print(arrays[0])
        [ 0  1  4  9 16]
        >>> arrays = stream.get_arrays(max_batch=1)
        >>> len(arrays)
        1
        >>> print(arrays[0])
        [0 1 4]
        >>> arrays = stream.get_arrays(max_batch=0)
        >>> len(arrays)
        1
        >>> print(arrays[0])
        []

        Args:
            max_batch: If specified, take at most this number of batches.

        Returns:
            The collected arrays.

        Raises:
            RuntimeError: If this data stream is empty.

                >>> def g():
                ...     if False:
                ...         yield ()
                >>> stream = DataStream.generator(g)
                >>> stream.get_arrays()
                Traceback (most recent call last):
                    ...
                RuntimeError: empty data stream cannot be converted to arrays
        """
        arrays_buf = []
        g = iter(self)
        try:
            try:
                batch = next(g)
            except StopIteration:
                raise RuntimeError(
                    'empty data stream cannot be converted to arrays')
            try:
                arrays_buf = [[to_number_or_numpy(arr)] for arr in batch]
                batch_index = 1
                while max_batch is None or batch_index < max_batch:
                    batch = next(g)
                    for i, arr in enumerate(batch):
                        arrays_buf[i].append(to_number_or_numpy(arr))
                    batch_index += 1
                if max_batch == 0:
                    arrays_buf = [[array_buf[0][:0]]
                                  for array_buf in arrays_buf]
            except StopIteration:
                pass
            return tuple(np.concatenate(array_buf)
                         for array_buf in arrays_buf)
        finally:
            g.close()

    def to_arrays_stream(self,
                         batch_size: int = NOT_SET,
                         shuffle: bool = False,
                         skip_incomplete: bool = False,
                         random_state: Optional[np.random.RandomState] = NOT_SET
                         ) -> 'ArraysDataStream':
        """
        Convert this data stream into an arrays stream.

        By default, the original batch size will be preserved:

        >>> stream = DataStream.int_seq(5, batch_size=3).map(lambda x: x ** 2)
        >>> isinstance(stream, MapperDataStream)
        True
        >>> stream2 = stream.to_arrays_stream()
        >>> isinstance(stream2, ArraysDataStream)
        True
        >>> for [a] in stream2:
        ...     print(a)
        [0 1 4]
        [ 9 16]

        You may also override the batch size:

        >>> stream3 = stream.to_arrays_stream(batch_size=4)
        >>> for [a] in stream3:
        ...     print(a)
        [0 1 4 9]
        [16]

        Args:
            batch_size: The number of data within each mini-batch.
                If not specified, will use the original batch size
                if possible.
            shuffle: Whether or not to shuffle data?
            skip_incomplete: Whether or not to exclude the last mini-batch
                if it is incomplete?
            random_state: The NumPy random state instance.  If not specified,
                will use the original random state instance.

        Returns:
            The constructed arrays stream.

        Raises:
            ValueError: If the batch size is neither specified, nor can it
                be determined from the original batch size.

                >>> def g():
                ...     yield np.arange(3)
                >>> stream = DataStream.generator(g)
                >>> stream.to_arrays_stream()
                Traceback (most recent call last):
                    ...
                ValueError: `batch_size` must be specified
        """
        if batch_size is NOT_SET:
            batch_size = self.batch_size
        if batch_size is None:
            raise ValueError('`batch_size` must be specified')
        if random_state is NOT_SET:
            random_state = self.random_state
        return ArraysDataStream(
            self.get_arrays(), batch_size=batch_size, shuffle=shuffle,
            skip_incomplete=skip_incomplete, random_state=random_state
        )

    # -------- here starts the factory methods --------

    @staticmethod
    def arrays(arrays: Iterable[Array],
               batch_size: int,
               shuffle: bool = False,
               skip_incomplete: bool = False,
               random_state: Optional[np.random.RandomState] = None
               ) -> 'ArraysDataStream':
        """
        Construct an arrays stream, i.e., :class:`ArraysDataStream`.

        >>> x = np.arange(5, dtype=np.int32)
        >>> y = x ** 2
        >>> stream = DataStream.arrays([x, y], batch_size=3)
        >>> for [a, b] in stream:
        ...     print(a, b)
        [0 1 2] [0 1 4]
        [3 4] [ 9 16]

        You may shuffle the data by setting `shuffle = True`:

        >>> np.random.seed(1234)
        >>> stream = DataStream.arrays([x, y], batch_size=3, shuffle=True)
        >>> for [a, b] in stream:
        ...     print(a, b)
        [4 0 1] [16  0  1]
        [2 3] [4 9]

        You may discard the last incomplete mini-batch by setting
        `skip_incomplete = True`:

        >>> stream = DataStream.arrays(
        ...     [x, y], batch_size=3, skip_incomplete=True)
        >>> for [a, b] in stream:
        ...     print(a, b)
        [0 1 2] [0 1 4]

        Args:
            arrays: A sequence of numpy-like arrays.  These arrays should be
                at least 1-d, and the sizes of their first axes must be
                identical.
            batch_size: The number of data within each mini-batch.
            shuffle: Whether or not to shuffle data?
            skip_incomplete: Whether or not to exclude the last mini-batch
                if it is incomplete?
            random_state: The NumPy random state instance.

        Returns:
            The arrays stream.
        """
        return ArraysDataStream(
            arrays=arrays, batch_size=batch_size, shuffle=shuffle,
            skip_incomplete=skip_incomplete, random_state=random_state
        )

    @staticmethod
    def int_seq(start: int,
                stop: int = None,
                step: int = None,
                *,
                dtype=np.int32,
                batch_size: int = NOT_SET,
                shuffle: bool = False,
                skip_incomplete: bool = False,
                random_state: Optional[np.random.RandomState] = None
                ) -> 'IntSeqDataStream':
        """
        Construct an integer sequence stream, i.e., :class:`IntSeqDataStream`.

        To construct various integer sequences:

        >>> stream = DataStream.int_seq(5, batch_size=3)
        >>> for [a] in stream:
        ...     print(a)
        [0 1 2]
        [3 4]
        >>> stream = DataStream.int_seq(2, 11, 2, batch_size=3)
        >>> for [a] in stream:
        ...     print(a)
        [2 4 6]
        [ 8 10]
        >>> stream = DataStream.int_seq(-5, step=-1, batch_size=3)
        >>> for [a] in stream:
        ...     print(a)
        [ 0 -1 -2]
        [-3 -4]
        >>> stream = DataStream.int_seq(-2, -11, -2, batch_size=3)
        >>> for [a] in stream:
        ...     print(a)
        [-2 -4 -6]
        [ -8 -10]

        You may shuffle the sequence by setting `shuffle = True`:

        >>> np.random.seed(1234)
        >>> stream = DataStream.int_seq(5, batch_size=3, shuffle=True)
        >>> for [a] in stream:
        ...     print(a)
        [4 0 1]
        [2 3]

        You may discard the last incomplete mini-batch by setting
        `skip_incomplete = True`:

        >>> stream = DataStream.int_seq(5, batch_size=3, skip_incomplete=True)
        >>> for [a] in stream:
        ...     print(a)
        [0 1 2]

        Args:
            start: If `stop` is specified, this is the starting number.
                Otherwise this is the ending number, and the starting
                number is 0.
            stop: The ending number.
            step: The incremental step of the sequence.
            dtype: The NumPy data type.
            batch_size: The number of data within each mini-batch.
            shuffle: Whether or not to shuffle data?
            skip_incomplete: Whether or not to exclude the last mini-batch
                if it is incomplete?
            random_state: The NumPy random state instance.

        Returns:
            The integer sequence stream.
        """
        return IntSeqDataStream(
            start=start, stop=stop, step=step, dtype=dtype,
            batch_size=batch_size, shuffle=shuffle,
            skip_incomplete=skip_incomplete, random_state=random_state,
        )

    @staticmethod
    def gather(streams: Iterable['DataStream'],
               random_state: Optional[np.random.RandomState] = None
               ) -> 'GatherDataStream':
        return GatherDataStream(streams=streams, random_state=random_state)

    @staticmethod
    def generator(f: Callable[[], ArraysOrArrayGenerator]
                  ) -> 'GeneratorFactoryDataStream':
        return GeneratorFactoryDataStream(f)

    # -------- here starts the transforming methods --------

    def map(self,
            mapper: Callable[..., ArraysOrArray],
            preserve_shapes: bool = False
            ) -> 'MapperDataStream':
        """
        Transform this data stream via a mapper function.

        To return a single array:

        >>> def mapper(x, y):
        ...     return x + y
        >>> x = np.arange(0, 5, dtype=np.int32)
        >>> y = np.arange(5, 10, dtype=np.int32)
        >>> stream = DataStream.arrays([x, y], batch_size=3).map(mapper)
        >>> for [a] in stream:
        ...     print(a)
        [5 7 9]
        [11 13]

        To return a tuple / list of arrays:

        >>> def mapper(x, y):
        ...     return x + y, x * y  # or return [x + y, x * y]
        >>> x = np.arange(0, 5, dtype=np.int32)
        >>> y = np.arange(5, 10, dtype=np.int32)
        >>> stream = DataStream.arrays([x, y], batch_size=3).map(mapper)
        >>> for [a, b] in stream:
        ...     print(a, b)
        [5 7 9] [ 0  6 14]
        [11 13] [24 36]

        Args:
            mapper: The mapper function.
            preserve_shapes: A user-specified hint, indicating whether or not
                the `mapper` preserves the array count and data shapes within
                each mini-batch.  This hint might benefit further
                transformations.  Defaults to :obj:`False`.

                >>> def mapper(x, y):
                ...     return x ** 2, y - 1
                >>> x = np.random.normal(size=[5, 4])
                >>> y = np.random.normal(size=[5, 3, 2])
                >>> stream = DataStream.arrays([x, y], batch_size=3)
                >>> stream.array_count, stream.data_shapes
                (2, ((4,), (3, 2)))
                >>> stream2 = stream.map(mapper)
                >>> stream2.array_count, stream2.data_shapes
                (None, None)
                >>> stream3 = stream.map(mapper, preserve_shapes=True)
                >>> stream3.array_count, stream3.data_shapes
                (2, ((4,), (3, 2)))

        Returns:
            The transformed data stream.
        """
        return MapperDataStream(
            source=self, mapper=mapper, preserve_shapes=preserve_shapes)

    def threaded(self, prefetch: int = 5) -> 'ThreadingDataStream':
        """
        Construct a data stream that prefetches this data stream in a
        background thread.

        >>> stream = DataStream.int_seq(5, batch_size=3)
        >>> with stream.threaded() as prefetch_stream:
        ...     for [x] in prefetch_stream:
        ...         print(x)
        [0 1 2]
        [3 4]

        Args:
            prefetch: Number of mini-batches to prefetch in background.

        Returns:
            The background data stream.
        """
        return ThreadingDataStream(self, prefetch=prefetch)

    def select(self, indices: Iterable[int]) -> 'MapperDataStream':
        """
        Construct a data stream that selects a subset of the arrays within
        each mini-batch, or re-orders the arrays.

        Given the following source data stream:

        >>> x = np.arange(0, 5, dtype=np.int32)
        >>> y = np.arange(5, 10, dtype=np.int32)
        >>> z = np.arange(10, 15, dtype=np.int32)
        >>> source = DataStream.arrays([x, y, z], batch_size=3)

        We shall select [x, z, x] from the source:

        >>> stream = source.select([0, 2, 0])
        >>> for [a, b, c] in stream:
        ...     print(a, b, c)
        [0 1 2] [10 11 12] [0 1 2]
        [3 4] [13 14] [3 4]

        The various data stream properties are also properly inherited:

        >>> x = np.random.normal(size=[5, 4])
        >>> y = np.random.normal(size=[5, 2, 3])
        >>> source = DataStream.arrays([x, y], batch_size=3)
        >>> stream = source.select([-1, 0, 1])
        >>> stream.array_count
        3
        >>> stream.data_shapes
        ((2, 3), (4,), (2, 3))
        >>> stream.data_length
        5

        Args:
            indices: The indices of the arrays to select within each
                mini-batch.

        Returns:
            The transformed data stream.

        Raises:
            IndexError: If `self.array_count` is reported, and any index in
                `indices` is out of range.

                >>> x = np.arange(0, 5, dtype=np.int32)
                >>> y = np.arange(5, 10, dtype=np.int32)
                >>> stream = DataStream.arrays([x, y], batch_size=3)
                >>> stream.select([0, 1, 2])
                Traceback (most recent call last):
                    ...
                IndexError: array index out of range

                Note that if `self.array_count` is not reported (i.e., is
                None), then :class:`IndexError` will not be raised until the
                stream is iterated.

                >>> def mapper(x, y, z):
                ...     return x + y, y - z
                >>> x = np.arange(0, 5, dtype=np.int32)
                >>> y = np.arange(5, 10, dtype=np.int32)
                >>> z = np.arange(10, 15, dtype=np.int32)
                >>> stream = DataStream.arrays([x, y, z], batch_size=3). \
                ...     map(mapper).select([0, 1, 2])
                >>> for batch in stream:
                ...     print(batch)
                Traceback (most recent call last):
                    ...
                IndexError: tuple index out of range
        """
        # validate the argument
        indices = tuple(indices)
        if self.array_count is not None:
            for i in indices:
                if i < -self.array_count or i >= self.array_count:
                    raise IndexError('array index out of range')

        # prepare for the mapper
        def mapper(*arrays):
            return tuple(arrays[j] for j in indices)

        # construct the mapper data stream
        if self.data_shapes is not None:
            data_shapes = tuple(self.data_shapes[i] for i in indices)
        else:
            data_shapes = None
        array_count = len(indices)

        return MapperDataStream(
            source=self, mapper=mapper, data_shapes=data_shapes,
            array_count=array_count
        )


class ArraysDataStream(DataStream):
    """NumPy arrays data stream."""

    def __init__(self,
                 arrays: Iterable[Array],
                 batch_size: int,
                 shuffle: bool,
                 skip_incomplete: bool,
                 random_state: Optional[np.random.RandomState] = None):
        # validate the parameters
        arrays = tuple(arrays)
        if not arrays:
            raise ValueError('`arrays` must not be empty.')
        for a in arrays:
            if not hasattr(a, 'shape'):
                raise ValueError('`arrays` must be arrays.')
            if len(a.shape) < 1:
                raise ValueError('`arrays` must be at least 1-d arrays.')
        data_shapes = to_data_shapes(arr.shape[1:] for arr in arrays)

        array_length = len(arrays[0])
        for a in arrays[1:]:
            if len(a) != array_length:
                raise ValueError('`arrays` must have the same length.')

        if skip_incomplete:
            data_length = array_length // batch_size * batch_size
        else:
            data_length = array_length

        # construct the instance
        super().__init__(
            batch_size=batch_size,
            array_count=len(data_shapes),
            data_shapes=data_shapes,
            data_length=data_length,
            random_state=random_state,
        )
        self._arrays = map_to_tuple(to_readonly_array, arrays)
        self._indices_buffer = None  # type: Array
        self._shuffle = bool(shuffle)
        self._skip_incomplete = bool(skip_incomplete)

    @property
    def the_arrays(self):
        """Get the underlying NumPy arrays without copy."""
        return self._arrays

    @property
    def shuffle(self) -> bool:
        """Whether or not to shuffle data?"""
        return self._shuffle

    @property
    def skip_incomplete(self) -> bool:
        """Whether or not to exclude the last mini-batch if it is
        incomplete?"""
        return self._skip_incomplete

    def _minibatch_iterator(self) -> Generator[ArrayTuple, None, None]:
        # shuffle the source arrays if necessary
        if self.shuffle:
            if self._indices_buffer is None:
                indices_count = len(self._arrays[0])
                t = np.int32 if indices_count < (1 << 31) else np.int64
                self._indices_buffer = np.arange(indices_count, dtype=t)
            rng = self._random_state or np.random
            rng.shuffle(self._indices_buffer)

            def get_slice(s):
                return tuple(
                    a[self._indices_buffer[s]] for a in self.the_arrays
                )
        else:
            def get_slice(s):
                return tuple(a[s] for a in self.the_arrays)

        # now iterate through the mini-batches
        for batch_s in minibatch_slices_iterator(
                length=self.data_length,
                batch_size=self.batch_size,
                skip_incomplete=self.skip_incomplete):
            yield get_slice(batch_s)

    def copy(self, **kwargs):
        return self._copy_helper(
            ('batch_size', 'shuffle', 'skip_incomplete', 'random_state'),
            arrays=self._arrays,
            **kwargs
        )


class IntSeqDataStream(DataStream):
    """Integer sequence data stream."""

    def __init__(self,
                 start: int,
                 stop: int = None,
                 step: int = None,
                 *,
                 dtype=np.int32,
                 batch_size: int = NOT_SET,
                 shuffle: bool = False,
                 skip_incomplete: bool = False,
                 random_state: Optional[np.random.RandomState] = None):
        # validate the arguments
        start = int(start)
        if stop is None:
            stop = start
            start = 0
        else:
            stop = int(stop)
        if step is None:
            step = 1
        else:
            step = int(step)
        dtype = np.dtype(dtype)
        if batch_size is NOT_SET:
            raise ValueError('`batch_size` is required.')

        # construct the int sequence
        seq = np.arange(start=start, stop=stop, step=step, dtype=dtype)
        if skip_incomplete:
            data_length = len(seq) // batch_size * batch_size
        else:
            data_length = len(seq)

        # construct the instance
        super().__init__(
            batch_size=batch_size,
            array_count=1,
            data_shapes=((),),
            data_length=data_length,
            random_state=random_state,
        )
        self._start = start
        self._stop = stop
        self._step = step
        self._dtype = dtype
        self._seq = seq
        self._shuffle = bool(shuffle)
        self._skip_incomplete = bool(skip_incomplete)

    @property
    def start(self) -> int:
        """Get the starting number."""
        return self._start

    @property
    def stop(self) -> int:
        """Get the ending number."""
        return self._stop

    @property
    def step(self) -> int:
        """Get the incremental step of the sequence."""
        return self._step

    @property
    def dtype(self) -> np.dtype:
        """Get the NumPy data type."""
        return self._dtype

    @property
    def shuffle(self) -> bool:
        """Whether or not to shuffle data?"""
        return self._shuffle

    @property
    def skip_incomplete(self) -> bool:
        """Whether or not to exclude the last mini-batch if it is
        incomplete?"""
        return self._skip_incomplete

    def _minibatch_iterator(self):
        if self.shuffle:
            rng = self._random_state or np.random
            rng.shuffle(self._seq)
        for batch_s in minibatch_slices_iterator(
                length=self.data_length,
                batch_size=self.batch_size,
                skip_incomplete=self.skip_incomplete):
            yield (to_readonly_array(self._seq[batch_s]),)

    def copy(self, **kwargs):
        return self._copy_helper(
            ('dtype', 'batch_size', 'shuffle', 'skip_incomplete',
             'random_state'),
            start=self.start, stop=self.stop, step=self.step,
            **kwargs
        )


class UserGeneratorDataStream(DataStream):
    """Base class for data streams with user-generated data."""

    def _validate_batch(self, batch):
        batch = ensure_batch_is_tuple(batch)
        if self.batch_size is not None and batch:
            batch_size = len(batch[0])
            if batch_size > self.batch_size:
                raise ValueError(
                    f'batch size of the mapper output is not '
                    f'valid: expected <= {self.batch_size}, '
                    f'got {batch_size}'
                )
            for i, b in enumerate(batch[1:], 1):
                if len(b) != batch_size:
                    raise ValueError(
                        f'batch size of the {i}-th mapper output != '
                        f'the first output'
                    )
        if self.array_count is not None and len(batch) != self.array_count:
            raise ValueError(f'user generator returned invalid number of '
                             f'arrays: expected {self.array_count}, got '
                             f'{len(batch)}')
        if self.data_shapes is not None:
            for i, (x, y) in enumerate(zip(batch, self.data_shapes)):
                if x.shape[1:] != y:
                    raise ValueError(
                        f'data shape of the {i}-th mapper output is not '
                        f'valid: expected {y}, got {x.shape[1:]}'
                    )
        return batch


class GeneratorFactoryDataStream(UserGeneratorDataStream):
    """Data stream that turns a generator factory function into a stream."""

    def __init__(self, factory: Callable[[], ArraysOrArrayGenerator]):
        super().__init__()
        self._factory = factory

    @property
    def factory(self) -> Callable[[], Generator[Sequence[Array], None, None]]:
        """
        Get the generator factory function (i.e., the function that returns
        a mini-batch arrays generator).
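
        For example (a minimal sketch; `make_batches` is an illustrative
        user-supplied factory):

        >>> def make_batches():
        ...     yield np.arange(3)
        >>> stream = DataStream.generator(make_batches)
        >>> stream.factory is make_batches
        True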
""" return self._factory def _minibatch_iterator(self): g = self._factory() try: for batch in g: yield self._validate_batch(batch) finally: if hasattr(g, 'close'): # pragma: no cover g.close() def copy(self, **kwargs): return self._copy_helper((), factory=self.factory, **kwargs) class GatherDataStream(DataStream): """Data stream that gathers multiple streams into one.""" def __init__(self, streams: Iterable[DataStream], random_state: Optional[np.random.RandomState] = NOT_SET): # validate the streams streams = tuple(streams) if not streams: raise ValueError('At least one data stream should be specified.') for i, stream in enumerate(streams): if not isinstance(stream, DataStream): raise TypeError(f'The {i}-th element of `streams` is not an ' f'instance of DataStream: {stream}.') # inspect the properties of the data streams batch_size = NOT_SET array_count = 0 data_shapes = [] data_length = NOT_SET for i, stream in enumerate(streams): # check the batch size if stream.batch_size is not None: if batch_size is NOT_SET: batch_size = stream.batch_size elif batch_size != stream.batch_size: raise ValueError( f'Inconsistent batch size among the specified streams: ' f'encountered {stream.batch_size} at the {i}-th ' f'stream, but has already encountered {batch_size} ' f'before.' ) # check the array count if array_count is not None: if stream.array_count is not None: array_count += stream.array_count else: array_count = None # check the data shapes if data_shapes is not None: if stream.data_shapes is not None: data_shapes.extend(stream.data_shapes) else: data_shapes = None # check the data length if stream.data_length is not None: if data_length is NOT_SET: data_length = stream.data_length elif data_length != stream.data_length: raise ValueError( f'Inconsistent data length among the specified ' f'streams: encountered {stream.data_length} at ' f'the {i}-th stream, but has already encountered ' f'{data_length} before.' 
                    )

            # check the random state
            if stream.random_state is not None and random_state is NOT_SET:
                random_state = stream.random_state

        if batch_size is NOT_SET:
            batch_size = None
        if data_shapes is not None:
            data_shapes = tuple(data_shapes)
        if data_length is NOT_SET:
            data_length = None
        if random_state is NOT_SET:
            random_state = None

        # construct the instance
        super().__init__(
            batch_size=batch_size,
            array_count=array_count,
            data_shapes=data_shapes,
            data_length=data_length,
            random_state=random_state
        )
        self._streams = streams

    @property
    def streams(self) -> Tuple[DataStream, ...]:
        """Get the gathered data streams."""
        return self._streams

    def _minibatch_iterator(self):
        iterators = [iter(s) for s in self._streams]
        try:
            for batches in zip(*iterators):
                yield sum([tuple(b) for b in batches], ())
        finally:
            for i in iterators:
                if hasattr(i, 'close'):  # pragma: no cover
                    i.close()

    def copy(self, **kwargs):
        return self._copy_helper(('random_state',), streams=self.streams,
                                 **kwargs)


class MapperDataStream(UserGeneratorDataStream):
    """Data stream that transforms the source stream via a mapper function."""

    def __init__(self,
                 source: DataStream,
                 mapper: Callable[..., ArraysOrArray],
                 batch_size: Optional[int] = NOT_SET,
                 array_count: Optional[int] = NOT_SET,
                 data_shapes: Optional[Tuple[ArrayShape, ...]] = NOT_SET,
                 data_length: Optional[int] = NOT_SET,
                 random_state: Optional[np.random.RandomState] = NOT_SET,
                 preserve_shapes: bool = False):
        # validate the arguments
        if not isinstance(source, DataStream):
            raise TypeError(f'`source` is not a DataStream: {source!r}')
        if batch_size is NOT_SET:
            batch_size = source.batch_size
        if array_count is NOT_SET:
            if preserve_shapes:
                array_count = source.array_count
            else:
                array_count = None
        if data_shapes is NOT_SET:
            if preserve_shapes:
                data_shapes = source.data_shapes
            else:
                data_shapes = None
        if data_length is NOT_SET:
            data_length = source.data_length
        if random_state is NOT_SET:
            random_state = source.random_state

        super().__init__(
            batch_size=batch_size,
            array_count=array_count,
            data_shapes=data_shapes,
            data_length=data_length,
            random_state=random_state
        )
        self._source = source
        self._mapper = mapper

    @property
    def source(self) -> DataStream:
        """Get the source data stream."""
        return self._source

    def _minibatch_iterator(self):
        g = iter(self._source)
        try:
            for batch in g:
                yield self._validate_batch(
                    self._mapper(*ensure_batch_is_tuple(batch)))
        finally:
            g.close()

    def copy(self, **kwargs):
        return self._copy_helper(
            ('batch_size', 'array_count', 'data_shapes', 'data_length',
             'random_state'),
            source=self._source, mapper=self._mapper, **kwargs
        )


class ThreadingDataStream(DataStream, AutoInitAndCloseable):
    """
    Data stream that prefetches mini-batches from the source stream in a
    background thread.
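
    It is typically constructed via :meth:`DataStream.threaded`, but may
    also be constructed directly (a minimal sketch, mirroring the
    `threaded()` example above):

    >>> stream = DataStream.int_seq(5, batch_size=3)
    >>> with ThreadingDataStream(stream, prefetch=2) as prefetch_stream:
    ...     for [x] in prefetch_stream:
    ...         print(x)
    [0 1 2]
    [3 4]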
""" EPOCH_END = object() """Object to mark an ending position of an epoch.""" class ErrorBox(object): """Class to carry an error.""" def __init__(self, error): self.error = error def __init__(self, source: DataStream, prefetch: int): # validate the parameters if not isinstance(source, DataStream): raise TypeError(f'`source` is not a DataStream: {source!r}') prefetch = int(prefetch) if prefetch < 1: raise ValueError('`prefetch` must be at least 1') # construct the instance super().__init__( batch_size=source.batch_size, array_count=source.array_count, data_shapes=source.data_shapes, data_length=source.data_length, random_state=source.random_state, ) self._source = source self._prefetch = prefetch # internal states for background worker self._worker = None # type: Thread self._batch_queue = None # type: Queue self._epoch_counter = None # counter for tracking the active epoch self._stopping = None self._worker_alive = None self._worker_ready_sem = None @property def source(self) -> DataStream: """Get the source data stream.""" return self._source @property def prefetch(self) -> int: """Get the number of mini-batches to prefetch in background.""" return self._prefetch def _worker_func(self): active_epoch = self._epoch_counter self._worker_alive = True self._worker_ready_sem.release() try: while not self._stopping: # iterate through the mini-batches in the current epoch g = iter(self.source) try: for batch in g: if self._stopping or active_epoch < self._epoch_counter: break self._batch_queue.put((active_epoch, batch)) finally: g.close() # put the epoch ending mark into the queue if not self._stopping: self._batch_queue.put((active_epoch, self.EPOCH_END)) # move to the next epoch active_epoch += 1 except Exception as ex: # pragma: no cover getLogger(__name__).warning( f'{self.__class__.__qualname__} exited because of error', exc_info=True ) self._batch_queue.put((active_epoch, self.ErrorBox(ex))) raise finally: self._worker_alive = False def _init(self): # prepare for the worker states self._batch_queue = Queue(self.prefetch) self._epoch_counter = 0 self._stopping = False self._worker_ready_sem = Semaphore(value=0) # create and start the worker self._worker = Thread(target=self._worker_func) self._worker.daemon = True self._worker.start() # wait for the thread to show up self._worker_ready_sem.acquire() def _close(self): try: # prevent the worker thread from further work self._stopping = True # exhaust all remaining queue items to notify the background worker while not self._batch_queue.empty(): self._batch_queue.get() # wait until the worker exit self._worker.join() finally: self._worker = None self._batch_queue = None self._worker_ready_sem = None self._initialized = False def _minibatch_iterator(self): self.init() try: # iterate through one epoch while self._worker_alive or not self._batch_queue.empty(): epoch, payload = self._batch_queue.get() if epoch < self._epoch_counter: # we've got a remaining item from the last epoch, skip it pass elif epoch > self._epoch_counter: # pragma: no cover # we've accidentally got an item from the future epoch # it should be a bug, and we shall report it raise RuntimeError('Unexpected entry from future epoch.') elif payload is self.EPOCH_END: # we've got the epoch ending mark for the current epoch, # so we should break the loop break elif isinstance(payload, self.ErrorBox): # we've got an error, re-raise it self.close() raise payload.error else: # we've got a normal batch for the current epoch, # so yield it yield payload finally: self._epoch_counter += 1 def 
    def copy(self, **kwargs):
        return self._copy_helper(('prefetch',), source=self.source, **kwargs)
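
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the public API): it composes
# the factory and transforming methods defined above.  The function is never
# called at import time, and `slow_mapper` is a hypothetical example mapper.
def _demo_compose_streams():  # pragma: no cover
    import time

    def slow_mapper(x, y):
        # simulate an expensive per-batch transformation
        time.sleep(0.01)
        return x + y, x * y

    x = np.arange(0, 10, dtype=np.int32)
    y = np.arange(10, 20, dtype=np.int32)

    # arrays -> map -> threaded: the mapped mini-batches are prefetched in a
    # background thread while the caller consumes them
    stream = DataStream.arrays([x, y], batch_size=4).map(slow_mapper)
    with stream.threaded(prefetch=2) as prefetch_stream:
        for [a, b] in prefetch_stream:
            print(a, b)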