Skip to content

Commit 7ec4fc4

Browse files
WIP
1 parent 1ef521f commit 7ec4fc4

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

extra_keras_datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import absolute_import
22

33
from . import emnist
4-
from . import kmnist
4+
from . import kmnist
5+
from . import svhn

extra_keras_datasets/svhn.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
'''
2+
Import the SVHN dataset
3+
Source: http://ufldl.stanford.edu/housenumbers/
4+
Description: Street View House Numbers
5+
6+
~~~ Important note ~~~
7+
Please cite the following paper when using or referencing the dataset:
8+
Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, Andrew Y. Ng. "Reading Digits in Natural Images with Unsupervised Feature Learning." NIPS Workshop on Deep Learning and Unsupervised Feature Learning, 2011. Retrieved from http://ufldl.stanford.edu/housenumbers/nips2011_housenumbers.pdf
9+
10+
'''
11+
12+
from keras.utils.data_utils import get_file
13+
import numpy as np
14+
from zipfile import ZipFile
15+
from scipy import io as sio
16+
import os
17+
18+
def load_data(path='svhn_matlab.npz', type='normal'):
    """Loads the SVHN dataset.

    # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets). It is used as a prefix: the
            suffixes `_train`, `_test` and `_extra` are appended to it,
            one cached file per downloaded Matlab file.
        type: any of normal, extra (extra appends ~530K extra images for
            training). Any value other than 'extra' behaves like 'normal'.

    # Returns
        Tuple of Numpy arrays: `(input_train, target_train), (input_test, target_test)`.
        The input arrays have the sample axis first (the last axis of the
        Matlab 'X' array is moved to position 0); the target arrays are
        flattened to one dimension.
    """
    input_train, target_train = _load_svhn_split(
        f'{path}_train',
        'http://ufldl.stanford.edu/housenumbers/train_32x32.mat')
    input_test, target_test = _load_svhn_split(
        f'{path}_test',
        'http://ufldl.stanford.edu/housenumbers/test_32x32.mat')

    # Append extra data, if required
    if type == 'extra':
        input_extra, target_extra = _load_svhn_split(
            f'{path}_extra',
            'http://ufldl.stanford.edu/housenumbers/extra_32x32.mat')
        # np.concatenate takes ONE sequence of arrays; passing the arrays as
        # separate positional arguments makes the second one the `axis`.
        # Images and labels are joined in the same order (train first, extra
        # after) so that sample i still matches label i.
        input_train = np.concatenate((input_train, input_extra))
        target_train = np.concatenate((target_train, target_extra))

    # Return data
    return (input_train, target_train), (input_test, target_test)


def _load_svhn_split(cache_name, origin):
    """Download (or reuse the cached copy of) one SVHN .mat file and unpack it.

    Returns an `(inputs, targets)` pair: `inputs` is the file's 'X' array with
    its last axis moved to the front, `targets` is the file's 'y' array
    flattened to one dimension.
    """
    mat_path = get_file(cache_name, origin=origin)

    # Load data from Matlab file.
    # Source: https://stackoverflow.com/a/53547262
    mat = sio.loadmat(mat_path)

    # 'X' is indexed with the sample axis last (axis 3); move it to axis 0.
    inputs = np.rollaxis(mat['X'], 3, 0)
    targets = mat['y'].flatten()
    return inputs, targets

0 commit comments

Comments
 (0)