Commit 72d0d0dc authored by Amelie Royer 🐼

Reuse MNIST code in SVHN

parent 84880fe3
......@@ -9,9 +9,11 @@ import scipy.io
import tensorflow as tf
from .tfrecords_utils import *
from .mnist import MNISTLoader, MNISTFeatures
class SVHNConverter(Converter):
features = MNISTFeatures
def __init__(self, data_dir):
"""Initialize the object for the SVHN dataset in `data_dir`"""
......@@ -24,7 +26,7 @@ class SVHNConverter(Converter):
else:
self.data.append((name, data))
def convert(self, tfrecords_path, sort=False):
def convert(self, tfrecords_path, compression_type=None, sort=False):
"""Convert the dataset in TFRecords saved in the given `tfrecords_path`"""
for name, data in self.data:
# Load
......@@ -33,7 +35,7 @@ class SVHNConverter(Converter):
num_items = labels.shape[0]
# Write
writer_path = '%s_%s' % (tfrecords_path, name)
writer = tf.python_io.TFRecordWriter(writer_path)
writer = self.init_writer(writer_path, compression_type=compression_type)
labels_order = np.argsort(labels, axis=0) if sort else range(num_items)
for x, index in enumerate(labels_order):
print('\rLoad %s: %d / %d' % (name, x + 1, num_items), end='')
......@@ -41,35 +43,13 @@ class SVHNConverter(Converter):
img = img.astype(np.uint8)
class_id = int(labels[index, 0])
class_id = 0 if class_id == 10 else class_id
example = tf.train.Example(features=tf.train.Features(feature={
'class': int64_feature([class_id]),
'image': bytes_feature([img.tostring()]),
'id': int64_feature([index])}))
writer.write(example.SerializeToString())
writer.write(self.create_example_proto([class_id], [img.tostring()], [index]))
# End
writer.close()
print('\nWrote %s in file %s' % (name, writer_path))
print('\nWrote %s in file %s (%.2fMB)' % (
name, writer_path, os.path.getsize(writer_path) / 1e6))
print()
class SVHNLoader:
    """Parse serialized SVHN examples into image / class / id tensors.

    `shape` is the shape handed to `decode_raw_image` when decoding the stored
    image bytes. It is a class attribute so that it can be overridden without
    editing `parsing_fn`.
    """

    # Shape passed to `decode_raw_image` for every record.
    shape = (32, 32, 3)

    def __init__(self, image_size=None, verbose=False):
        """Init a Loader object.

        Args:
            image_size: Forwarded to `decode_raw_image` as its `image_size`
                argument; presumably the size loaded images are resized to
                (TODO confirm against the helper, including what `None` does).
            verbose: If true, `parsing_fn` calls `print_records` on the
                records it builds.
        """
        self.image_size = image_size
        self.verbose = verbose

    def parsing_fn(self, example_proto):
        """tf.data.Dataset parsing function.

        Args:
            example_proto: A serialized example holding the scalar features
                'class' (int64), 'image' (string) and 'id' (int64).

        Returns:
            A dict with keys 'image', 'class' and 'id'. The class and id
            entries are cast to int32.
        """
        # Basic features
        features = {
            'class': tf.FixedLenFeature((), tf.int64),
            'image': tf.FixedLenFeature((), tf.string),
            'id': tf.FixedLenFeature((), tf.int64),
        }
        parsed_features = tf.parse_single_example(example_proto, features)
        image = decode_raw_image(
            parsed_features['image'], self.shape, image_size=self.image_size)
        # Give the output tensors stable names ('image', 'class_label', 'index').
        image = tf.identity(image, name='image')
        class_id = tf.to_int32(parsed_features['class'], name='class_label')
        index = tf.to_int32(parsed_features['id'], name='index')
        # Return
        records_dict = {'image': image, 'class': class_id, 'id': index}
        if self.verbose:
            print_records(records_dict)
        return records_dict
\ No newline at end of file
class SVHNLoader(MNISTLoader):
    """Loader for SVHN records that reuses the MNIST loader implementation.

    This is a subclass rather than a plain alias (`SVHNLoader = MNISTLoader`):
    with an alias, assigning `SVHNLoader.shape` would rebind the attribute on
    the shared class object and therefore change `MNISTLoader.shape` as well.
    """

    # Image shape used for SVHN records; overrides the value from MNISTLoader.
    shape = (32, 32, 3)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment