Commit 9d8b29f

Merge branch 'iris'
2 parents 244507e + 879d768

File tree: 5 files changed (+100, −1 lines)


MANIFEST

Lines changed: 1 addition & 0 deletions
```diff
@@ -4,4 +4,5 @@ setup.py
 extra_keras_datasets\__init__.py
 extra_keras_datasets\emnist.py
 extra_keras_datasets\kmnist.py
+extra_keras_datasets\stl10.py
 extra_keras_datasets\svhn.py
```

README.md

Lines changed: 15 additions & 0 deletions
````diff
@@ -23,6 +23,7 @@ Hi there, and welcome to the `extra-keras-datasets` module! This extension to th
 * [SVHN-Normal](#svhn-normal)
 * [SVHN-Extra](#svhn-extra)
 * [STL-10](#stl-10)
+* [Iris](#iris)
 - [Contributors and other references](#contributors-and-other-references)
 - [License](#license)
 
@@ -167,6 +168,20 @@ from extra-keras-datasets import stl10
 
 ---
 
+### Iris
+This is perhaps the best-known database in the pattern recognition literature. Fisher's paper is a classic in the field and is still referenced frequently today (see Duda & Hart, for example). The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant. One class is linearly separable from the other 2; the latter are NOT linearly separable from each other.
+
+Predicted attribute: class of iris plant.
+
+```python
+from extra_keras_datasets import iris
+(input_train, target_train), (input_test, target_test) = iris.load_data(test_split=0.2)
+```
+
+<a href="./assets/iris.png"><img src="./assets/iris.png" width="100%" style="border: 3px solid #f6f8fa;" /></a>
+
+---
+
 ## Contributors and other references
 * **EMNIST dataset:**
 * Cohen, G., Afshar, S., Tapson, J., & van Schaik, A. (2017). EMNIST: an extension of MNIST to handwritten letters. Retrieved from http://arxiv.org/abs/1702.05373
````
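The targets returned by `iris.load_data` are integer class ids (0 = setosa, 1 = versicolor, 2 = virginica, per the module docstring). For training a Keras classifier with categorical crossentropy they would typically be one-hot encoded first; a minimal NumPy sketch of that encoding (`one_hot` is a hypothetical helper for illustration, not part of this package):

```python
import numpy as np

def one_hot(targets, num_classes=3):
    # Hypothetical helper: map integer class ids (0..num_classes-1)
    # to one-hot rows by indexing into an identity matrix.
    return np.eye(num_classes)[np.asarray(targets)]

encoded = one_hot([0, 2, 1])
print(encoded.tolist())
# [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
```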

assets/iris.png

Binary file added (38.9 KB)

extra_keras_datasets/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -3,4 +3,5 @@
 from . import emnist
 from . import kmnist
 from . import svhn
-from . import stl10
+from . import stl10
+from . import iris
```

extra_keras_datasets/iris.py

Lines changed: 82 additions & 0 deletions
New file:

```python
'''
Import the Iris dataset
Source: http://archive.ics.uci.edu/ml/datasets/Iris
Description: The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant.

~~~ Important note ~~~
Please cite the following paper when using or referencing the dataset:
Fisher, R.A. "The use of multiple measurements in taxonomic problems." Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to Mathematical Statistics" (John Wiley, NY, 1950).
'''

import math

import numpy as np
from keras.utils.data_utils import get_file


def load_data(path='iris.npz', test_split=0.2):
    '''Loads the Iris dataset.

    # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets).
        test_split: fraction of the data to use for testing (20% by default).

    # Returns
        Tuple of Numpy arrays: `(input_train, target_train), (input_test, target_test)`.
        Input structure: (sepal length, sepal width, petal length, petal width).
        Target structure: 0 = Iris setosa; 1 = Iris versicolor; 2 = Iris virginica.
    '''
    path = get_file(
        path,
        origin='http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data')

    # Read data from file
    with open(path, 'r') as f:
        lines = f.readlines()

    # Process each line into an input/target structure
    samples = []
    for line in lines:
        sample = line_to_list(line)
        if sample is not None:
            samples.append(sample)

    # Randomly shuffle the data
    np.random.shuffle(samples)

    # Compute the number of test samples from test_split
    num_test_samples = math.floor(len(samples) * test_split)

    # Split the data
    training_data = samples[num_test_samples:]
    testing_data = samples[:num_test_samples]

    # Split into inputs and targets
    input_train = np.array([i[0:4] for i in training_data])
    input_test = np.array([i[0:4] for i in testing_data])
    target_train = np.array([i[4] for i in training_data])
    target_test = np.array([i[4] for i in testing_data])

    # Return data
    return (input_train, target_train), (input_test, target_test)


def line_to_list(line):
    '''Convert a string-based line into a tuple with input and target data.'''
    elements = line.split(',')
    if len(elements) > 1:
        target = target_string_to_int(elements[4])
        full_sample = [float(i) for i in elements[0:4]]
        full_sample.append(target)
        return tuple(full_sample)
    return None


def target_string_to_int(target_value):
    '''Convert a string-based target value into an integer-based one.'''
    target_value = target_value.strip()
    if target_value == 'Iris-setosa':
        return 0
    elif target_value == 'Iris-versicolor':
        return 1
    return 2
```
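The parsing and splitting logic in `iris.py` can be exercised without downloading anything. A minimal standalone sketch, assuming each record looks like `5.1,3.5,1.4,0.2,Iris-setosa` (the helper names here mirror, but do not import, the module above):

```python
import math
import random

LABELS = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}

def parse_line(line):
    # One CSV record: four float features plus a class label.
    # Blank or malformed lines yield None so callers can skip them.
    elements = line.strip().split(',')
    if len(elements) < 5:
        return None
    return tuple(float(v) for v in elements[:4]) + (LABELS[elements[4]],)

def split_samples(samples, test_split=0.2):
    # Shuffle, then carve off the first floor(n * test_split)
    # samples as the test set, as load_data does.
    samples = list(samples)
    random.shuffle(samples)
    num_test = math.floor(len(samples) * test_split)
    return samples[num_test:], samples[:num_test]

lines = [
    '5.1,3.5,1.4,0.2,Iris-setosa\n',
    '7.0,3.2,4.7,1.4,Iris-versicolor\n',
    '6.3,3.3,6.0,2.5,Iris-virginica\n',
    '5.8,2.7,5.1,1.9,Iris-virginica\n',
    '\n',
]
samples = [s for s in (parse_line(l) for l in lines) if s is not None]
train, test = split_samples(samples, test_split=0.25)
print(len(train), len(test))  # 3 1
```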
