Commit 84880fe3 authored by Amelie Royer's avatar Amelie Royer 🐼

Reuse MNIST code in MNIST-M

parent eb3cc52b
......@@ -5,23 +5,25 @@ from __future__ import print_function
##############################
import os
import numpy as np
from matplotlib import image as mpimg
import tensorflow as tf
from .tfrecords_utils import *
from .tfrecords_utils import *
from .mnist import MNISTLoader, MNISTFeatures
class MNISTMConverter(Converter):
features = MNISTFeatures
def __init__(self, data_dir):
"""Initialize the object for the MNIST-M dataset in `data_dir`"""
print('Loading original MNIST-M data from', data_dir)
self.data_dir = data_dir
self.data = []
for name in ['train', 'test']:
split = os.path.join(data_dir, 'mnist_m_%s_labels.txt' % name)
image_dir = os.path.join(self.data_dir, 'mnist_m_%s' % name)
if not os.path.isfile(split):
image_dir = os.path.join(data_dir, 'mnist_m_%s' % name)
if not os.path.isfile(split):
print('Warning: Missing %s data' % name)
elif not os.path.exists(image_dir):
print('Warning: Missing %s image directory' % name)
......@@ -31,47 +33,26 @@ class MNISTMConverter(Converter):
labels = list(map(int, labels))
self.data.append((name, image_dir, images, labels))
def convert(self, tfrecords_path, sort=False):
def convert(self, tfrecords_path, compression_type=None, sort=False):
"""Convert the dataset in TFRecords saved in the given `tfrecords_path`"""
for name, image_dir, images, labels in self.data:
num_items = len(labels)
# Init writer
writer_path = '%s_%s' % (tfrecords_path, name)
writer = tf.python_io.TFRecordWriter(writer_path)
writer = self.init_writer(writer_path, compression_type=compression_type)
num_items = len(labels)
labels_order = np.argsort(labels, axis=0) if sort else range(num_items)
for x, index in enumerate(labels_order):
print('\rLoad %s: %d / %d' % (name, x + 1, num_items), end='')
img = np.ceil(255. * mpimg.imread(os.path.join(image_dir, images[index])))
img = img.astype(np.uint8)
class_id = labels[index]
example = tf.train.Example(features=tf.train.Features(feature={
'class': int64_feature([class_id]),
'image': bytes_feature([img.tostring()]),
'id': int64_feature([index])}))
writer.write(example.SerializeToString())
writer.write(self.create_example_proto([class_id], [img.tostring()], [index]))
# End writing
writer.close()
print('\nWrote %s in file %s' % (name, writer_path))
print('\nWrote %s in file %s (%.2fMB)' % (
name, writer_path, os.path.getsize(writer_path) / 1e6))
print()
class MNISTMLoader():
def __init__(self, image_size=None, verbose=False):
"""Init a Loader object. Loaded images will be resized to size `resize`."""
self.image_size = image_size
self.verbose = verbose
def parsing_fn(self, example_proto):
"""tf.data.Dataset parsing function."""
# Basic features
features = {'class': tf.FixedLenFeature((), tf.int64),
'image': tf.FixedLenFeature((), tf.string),
'id': tf.FixedLenFeature((), tf.int64)}
parsed_features = tf.parse_single_example(example_proto, features)
image = decode_raw_image(parsed_features['image'], (32, 32, 3), image_size=self.image_size)
image = tf.identity(image, name='image')
class_id = tf.to_int32(parsed_features['class'], name='class_label')
index = tf.to_int32(parsed_features['id'], name='index')
# Return
records_dict = {'image': image, 'class': class_id, 'id': index}
if self.verbose: print_records(records_dict)
return records_dict
\ No newline at end of file
MNISTMLoader = MNISTLoader
MNISTMLoader.shape = (32, 32, 3)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment