""" Simple dataset loaders. For more datasets and more comprehensive loaders, you may turn to dedicated libraries like `fuel`. """ import gzip import hashlib import os import pickle from typing import * import idx2numpy import numpy as np from ..typing_ import * from ..utils import CacheDir, validate_enum_arg __all__ = ['load_mnist', 'load_fashion_mnist', 'load_cifar10', 'load_cifar100'] _MNIST_LIKE_FILE_NAMES = { 'train_x': 'train-images-idx3-ubyte.gz', 'train_y': 'train-labels-idx1-ubyte.gz', 'test_x': 't10k-images-idx3-ubyte.gz', 'test_y': 't10k-labels-idx1-ubyte.gz', } _MNIST_URI_PREFIX = 'http://yann.lecun.com/exdb/mnist/' _MNIST_FILE_MD5 = { 'train_x': 'f68b3c2dcbeaaa9fbdd348bbdeb94873', 'train_y': 'd53e105ee54ea40749a09fcbcd1e9432', 'test_x': '9fb629c4189551a2d022fa330f9573f3', 'test_y': 'ec29112dd5afa0611ce80d1b7f02629c', } _FASHION_MNIST_URI_PREFIX = 'http://fashion-mnist.s3-website.eu-central-1.' \ 'amazonaws.com/' _FASHION_MNIST_FILE_MD5 = { 'train_x': '8d4fb7e6c68d591d4c3dfef9ec88bf0d', 'train_y': '25c81989df183df01b3e8a0aad5dffbe', 'test_x': 'bef4ecab320f06d8554ea6380940ec79', 'test_y': 'bb300cfdad3c16e7a12a480ee83cd310', } _CIFAR_10_URI = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' _CIFAR_10_MD5 = 'c58f30108f718f92721af3b95e74349a' _CIFAR_10_CONTENT_DIR = 'cifar-10-batches-py' _CIFAR_100_URI = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' _CIFAR_100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' _CIFAR_100_CONTENT_DIR = 'cifar-100-python' def _validate_x_shape(shape, default_shape): shape = tuple(int(v) for v in shape) default_shape = tuple(int(v) for v in default_shape) value_size = int(np.prod(default_shape)) if np.prod(shape) != value_size: raise ValueError(f'`x_shape` does not product to {value_size}: {shape}') return shape def load_mnist_like(uri_prefix: str, file_md5: Dict[str, str], cache_name: str, x_shape: Sequence[int] = (28, 28), x_dtype: ArrayDType = np.uint8, y_dtype: ArrayDType = np.int32 ) -> Tuple[XYArrayTuple, XYArrayTuple]: """ Load an MNIST-like dataset as NumPy arrays. Args: uri_prefix: Common prefix of the URIs in `remote_files`. file_md5: The remote file MD5 hash sums, a dict of `{'train_x': ..., 'train_y': ..., 'test_x': ..., 'test_y': ...}`, where each value is the md5 sum. cache_name: Name of the cache directory. x_shape: Reshape each digit into this shape. x_dtype: Cast each digit into this data type. y_dtype: Cast each label into this data type. Returns: The ``(train_x, train_y), (test_x, test_y)`` arrays. """ def _fetch_array(array_name): uri = uri_prefix + _MNIST_LIKE_FILE_NAMES[array_name] md5 = file_md5[array_name] path = CacheDir(cache_name).download( uri, hasher=hashlib.md5(), expected_hash=md5) with gzip.open(path, 'rb') as f: return idx2numpy.convert_from_file(f) # check arguments x_shape = _validate_x_shape(x_shape, (28, 28)) # load data train_x = _fetch_array('train_x').astype(x_dtype) train_y = _fetch_array('train_y').astype(y_dtype) test_x = _fetch_array('test_x').astype(x_dtype) test_y = _fetch_array('test_y').astype(y_dtype) assert(len(train_x) == len(train_y) == 60000) assert(len(test_x) == len(test_y) == 10000) # change shape train_x = train_x.reshape([len(train_x)] + list(x_shape)) test_x = test_x.reshape([len(test_x)] + list(x_shape)) return (train_x, train_y), (test_x, test_y) def load_mnist(x_shape: Sequence[int] = (28, 28), x_dtype: ArrayDType = np.uint8, y_dtype: ArrayDType = np.int32 ) -> Tuple[XYArrayTuple, XYArrayTuple]: """ Load an MNIST dataset as NumPy arrays. Args: x_shape: Reshape each digit into this shape. x_dtype: Cast each digit into this data type. y_dtype: Cast each label into this data type. Returns: The ``(train_x, train_y), (test_x, test_y)`` arrays. """ return load_mnist_like( _MNIST_URI_PREFIX, _MNIST_FILE_MD5, 'mnist', x_shape, x_dtype, y_dtype) def load_fashion_mnist(x_shape: Sequence[int] = (28, 28), x_dtype: ArrayDType = np.uint8, y_dtype: ArrayDType = np.int32 ) -> Tuple[XYArrayTuple, XYArrayTuple]: """ Load an MNIST dataset as NumPy arrays. Args: x_shape: Reshape each digit into this shape. x_dtype: Cast each digit into this data type. y_dtype: Cast each label into this data type. Returns: The ``(train_x, train_y), (test_x, test_y)`` arrays. """ return load_mnist_like( _FASHION_MNIST_URI_PREFIX, _FASHION_MNIST_FILE_MD5, 'fashion_mnist', x_shape, x_dtype, y_dtype) def _cifar_load_batch(path, x_shape, x_dtype, y_dtype, expected_batch_label, labels_key='labels'): # load from file with open(path, 'rb') as f: d = { k.decode('utf-8'): v for k, v in pickle.load(f, encoding='bytes').items() } d['batch_label'] = d['batch_label'].decode('utf-8') assert(d['batch_label'] == expected_batch_label) data = np.asarray(d['data'], dtype=x_dtype) labels = np.asarray(d[labels_key], dtype=y_dtype) # change shape data = data.reshape((data.shape[0], 3, 32, 32)) data = np.transpose(data, (0, 2, 3, 1)) if x_shape: data = data.reshape([data.shape[0]] + list(x_shape)) return data, labels def load_cifar10(x_shape: Sequence[int] = (32, 32, 3), x_dtype: ArrayDType = np.float32, y_dtype: ArrayDType = np.int32) -> Tuple[XYArrayTuple, XYArrayTuple]: """ Load the CIFAR-10 dataset as NumPy arrays. Args: x_shape: Reshape each digit into this shape. x_dtype: Cast each digit into this data type. y_dtype: Cast each label into this data type. Returns: The ``(train_x, train_y), (test_x, test_y)`` arrays. """ # check the arguments x_shape = _validate_x_shape(x_shape, (32, 32, 3)) # fetch data path = CacheDir('cifar').download_and_extract( _CIFAR_10_URI, hasher=hashlib.md5(), expected_hash=_CIFAR_10_MD5) data_dir = os.path.join(path, _CIFAR_10_CONTENT_DIR) # load the data train_num = 50000 train_x = np.zeros((train_num,) + x_shape, dtype=x_dtype) train_y = np.zeros((train_num,), dtype=y_dtype) for i in range(1, 6): path = os.path.join(data_dir, 'data_batch_{}'.format(i)) x, y = _cifar_load_batch( path, x_shape=x_shape, x_dtype=x_dtype, y_dtype=y_dtype, expected_batch_label='training batch {} of 5'.format(i) ) (train_x[(i - 1) * 10000: i * 10000, ...], train_y[(i - 1) * 10000: i * 10000]) = x, y path = os.path.join(data_dir, 'test_batch') test_x, test_y = _cifar_load_batch( path, x_shape=x_shape, x_dtype=x_dtype, y_dtype=y_dtype, expected_batch_label='testing batch 1 of 1' ) assert(len(test_x) == len(test_y) == 10000) return (train_x, train_y), (test_x, test_y) def load_cifar100(label_mode: str = 'fine', x_shape: Sequence[int] = (32, 32, 3), x_dtype: ArrayDType = np.float32, y_dtype: ArrayDType = np.int32) -> Tuple[XYArrayTuple, XYArrayTuple]: """ Load the CIFAR-100 dataset as NumPy arrays. Args: label_mode: One of {"fine", "coarse"}. x_shape: Reshape each digit into this shape. x_dtype: Cast each digit into this data type. y_dtype: Cast each label into this data type. Returns: The ``(train_x, train_y), (test_x, test_y)`` arrays. """ # check the arguments label_mode = validate_enum_arg('label_mode', label_mode, ('fine', 'coarse')) x_shape = _validate_x_shape(x_shape, (32, 32, 3)) # fetch data path = CacheDir('cifar').download_and_extract( _CIFAR_100_URI, hasher=hashlib.md5(), expected_hash=_CIFAR_100_MD5) data_dir = os.path.join(path, _CIFAR_100_CONTENT_DIR) # load the data path = os.path.join(data_dir, 'train') train_x, train_y = _cifar_load_batch( path, x_shape=x_shape, x_dtype=x_dtype, y_dtype=y_dtype, expected_batch_label='training batch 1 of 1', labels_key='{}_labels'.format(label_mode) ) assert(len(train_x) == len(train_y) == 50000) path = os.path.join(data_dir, 'test') test_x, test_y = _cifar_load_batch( path, x_shape=x_shape, x_dtype=x_dtype, y_dtype=y_dtype, expected_batch_label='testing batch 1 of 1', labels_key='{}_labels'.format(label_mode) ) assert(len(test_x) == len(test_y) == 10000) return (train_x, train_y), (test_x, test_y)