Commit eb3cc52b authored by Amelie Royer's avatar Amelie Royer 🐼

Adding compressed TFRecords code for MNIST

parent 9306cd9f
......@@ -8,15 +8,21 @@ import os
import numpy as np
import tensorflow as tf
from .tfrecords_utils import *
from .tfrecords_utils import *
def read_integer(bytel):
return int('0x' + ''.join('{:02x}'.format(x) for x in bytel), 0)
"""Define features to be stored in the TFRecords"""
MNISTFeatures = Features([('class', FeatureType.INT, FeatureLength.FIXED, (),),
('image', FeatureType.BYTES, FeatureLength.FIXED, (),),
('id', FeatureType.INT, FeatureLength.FIXED, (),)])
class MNISTConverter(Converter):
features = MNISTFeatures
def __init__(self, data_dir):
"""Initialize the object for the MNIST dataset in `data_dir`"""
print('Loading original MNIST data from', data_dir)
......@@ -24,19 +30,16 @@ class MNISTConverter(Converter):
for name, key in [('train', 'train'), ('test', 't10k')]:
images = os.path.join(data_dir, '%s-images.idx3-ubyte' % key)
labels = os.path.join(data_dir, '%s-labels.idx1-ubyte' % key)
if not os.path.isfile(images) or not os.path.isfile(labels):
if not os.path.isfile(images) or not os.path.isfile(labels):
print('Warning: Missing %s data' % name)
else:
self.data.append((name, images, labels))
def convert(self, tfrecords_path, sort=False):
def convert(self, tfrecords_path, compression_type=None, sort=False):
"""Convert the dataset in TFRecords saved in the given `tfrecords_path`
If `sort` is True, the Example will be sorted by class in the final TFRecords.
"""
for name, images, labels in self.data:
if images is None or labels is None:
print('Warning: Missing %s data' % name)
continue
for name, images, labels in self.data:
# Read images
with codecs.open(images, 'r', 'latin-1') as f:
block = list(bytearray(f.read(), 'latin-1'))
......@@ -44,12 +47,14 @@ class MNISTConverter(Converter):
num_items = read_integer(block[4:8])
num_rows = read_integer(block[8:12])
num_columns = read_integer(block[12:16])
# Read labels
with codecs.open(labels, 'r', 'latin-1') as f:
blockLabels = list(bytearray(f.read(), 'latin-1'))[8:]
# Write
writer_path = '%s_%s' % (tfrecords_path, name)
writer = tf.python_io.TFRecordWriter(writer_path)
writer = self.init_writer(writer_path, compression_type=compression_type)
offset = 16
num_pixels = num_rows * num_columns
labels_order = np.argsort(blockLabels) if sort else range(num_items)
......@@ -57,40 +62,36 @@ class MNISTConverter(Converter):
print('\rLoad %s: %d / %d' % (name, x + 1, num_items), end='')
step = offset + index * num_pixels
next_step = step + num_pixels
img = np.array(block[step:next_step]).reshape((num_rows, num_columns))
img = img.astype(np.uint8)
img = np.array(block[step:next_step], dtype=np.uint8).reshape((num_rows, num_columns))
class_id = blockLabels[index]
writer.write(self.create_example_proto([class_id], [img.tostring()], [index]))
step = next_step
example = tf.train.Example(features=tf.train.Features(feature={
'class': int64_feature([class_id]),
'image': bytes_feature([img.tostring()]),
'id': int64_feature([index])}))
writer.write(example.SerializeToString())
# End
writer.close()
print('\nWrote %s in file %s' % (name, writer_path))
print('\nWrote %s in file %s (%.2fMB)' % (
name, writer_path, os.path.getsize(writer_path) / 1e6))
print()
class MNISTLoader():
class MNISTLoader(Loader):
features = MNISTFeatures
shape = (28, 28, 1)
def __init__(self, image_size=None, verbose=False):
"""Init a Loader object. Loaded images will be resized to size `image_size`."""
self.image_size = image_size
self.verbose = verbose
def parsing_fn(self, example_proto):
"""tf.data.Dataset parsing function."""
# Basic features
features = {'class': tf.FixedLenFeature((), tf.int64),
'image': tf.FixedLenFeature((), tf.string),
'id': tf.FixedLenFeature((), tf.int64)}
parsed_features = tf.parse_single_example(example_proto, features)
# Parse
image = decode_raw_image(parsed_features['image'], (28, 28, 1), image_size=self.image_size)
image = tf.identity(image, name='image')
class_id = tf.to_int32(parsed_features['class'], name='class_label')
index = tf.to_int32(parsed_features['id'], name='index')
parsed_features = self.raw_parsing_fn(example_proto)
# Reshape
parsed_features['image'] = decode_raw_image(parsed_features['image'], self.shape, image_size=self.image_size)
parsed_features['image'] = tf.identity(parsed_features['image'], name='image')
parsed_features['class'] = tf.to_int32(parsed_features['class'], name='class_label')
parsed_features['id'] = tf.to_int32(parsed_features['id'], name='index')
# Return
records_dict = {'image': image, 'class': class_id, 'id': index}
if self.verbose: print_records(records_dict)
return records_dict
\ No newline at end of file
if self.verbose: print_records(parsed_features)
return parsed_features
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment