Commit dd56c98d authored by Amelie Royer's avatar Amelie Royer 🐼

Adding compressed TFRecords code for M2NIST

parent 56a6774c
......@@ -7,11 +7,18 @@ import os
import numpy as np
import tensorflow as tf
from .tfrecords_utils import *
from .tfrecords_utils import *
"""Define features to be stored in the TFRecords"""
M2NISTFeatures = Features([('mask', FeatureType.BYTES, FeatureLength.FIXED, (), None),
('image', FeatureType.BYTES, FeatureLength.FIXED, (), None),
('id', FeatureType.INT, FeatureLength.FIXED, (), None)])
class M2NISTConverter(Converter):
features = M2NISTFeatures
def __init__(self, data_dir):
"""Initialize the object for the M2NIST dataset in `data_dir`"""
print('Loading original Multidigit MNIST data from', data_dir)
......@@ -20,7 +27,7 @@ class M2NISTConverter(Converter):
assert os.path.isfile(self.images)
assert os.path.isfile(self.masks)
def convert(self, tfrecords_path, train_split=0.7, val_split=0.1, test_split=0.2):
def convert(self, tfrecords_path, train_split=0.7, val_split=0.1, test_split=0.2, compression_type=None):
"""Convert the dataset in TFRecords saved in the given `tfrecords_path`
If `sort` is True, the Example will be sorted by class in the final TFRecords.
"""
......@@ -36,26 +43,23 @@ class M2NISTConverter(Converter):
# Write
for name, start, end in [('train', 0, train_fence),
('val', train_fence, val_fence),
('test', val_fence, test_fence)]:
if start == end:
('test', val_fence, test_fence)]:
if start == end:
print('Warning: Empty %s split' % name)
continue
writer_path = '%s_%s' % (tfrecords_path, name)
writer = tf.python_io.TFRecordWriter(writer_path)
writer = self.init_writer(writer_path, compression_type=compression_type)
for i, index in enumerate(indices[start:end]):
print('\rLoad %s: %d / %d' % (name, i + 1, end - start), end='')
img = images[i].astype(np.uint8)
mask = masks[i, :, :, :10].astype(np.float32)
example = tf.train.Example(features=tf.train.Features(feature={
'mask': bytes_feature([mask.tostring()]),
'image': bytes_feature([img.tostring()]),
'id': int64_feature([index])}))
writer.write(example.SerializeToString())
writer.write(self.create_example_proto([mask.tostring()], [img.tostring()], [index]))
writer.close()
print('\nWrote %s in file %s' % (name, writer_path))
print('\nWrote %s in file %s (%.2fMB)' % (
name, writer_path, os.path.getsize(writer_path) / 1e6))
print()
def viz_mask(mask):
"""Given a (batch, w, h, 10) array, returns a visualization"""
rgb_palette = np.array([(248, 183, 205), (246, 210, 224), (200, 231, 245), (103, 163, 217), (6, 113, 183),
......@@ -64,34 +68,32 @@ def viz_mask(mask):
mask = np.tile(mask, (1, 1, 1, 1, 3))
mask = np.sum(mask * rgb_palette, axis=-2)
mask = np.clip(mask, 0, 255)
return mask
class M2NISTLoader():
return mask
class M2NISTLoader(Loader):
features = M2NISTFeatures
def __init__(self, image_size=None, verbose=False):
"""Init a Loader object. Loaded images will be resized to size `image_size`."""
self.image_size = image_size
self.verbose = verbose
self.verbose = verbose
def parsing_fn(self, example_proto):
"""tf.data.Dataset parsing function."""
# Basic features
features = {'mask': tf.FixedLenFeature((), tf.string),
'image': tf.FixedLenFeature((), tf.string),
'id': tf.FixedLenFeature((), tf.int64)}
parsed_features = tf.parse_single_example(example_proto, features)
# Image
image = decode_raw_image(parsed_features['image'], (64, 84, 1), image_size=self.image_size)
image = tf.identity(image, name='image')
# Mask
mask = tf.decode_raw(parsed_features['mask'], tf.float32)
mask = tf.reshape(mask, (64, 84, 10))
# Parse
parsed_features = self.raw_parsing_fn(example_proto)
# Reshape
parsed_features['image'] = decode_raw_image(parsed_features['image'], (64, 84, 1), image_size=self.image_size)
parsed_features['image'] = tf.identity(parsed_features['image'], name='image')
parsed_features['mask'] = tf.decode_raw(parsed_features['mask'], tf.float32)
parsed_features['mask'] = tf.reshape(parsed_features['mask'], (64, 84, 10))
if self.image_size is not None:
mask = tf.image.resize_images(mask, (self.image_size, self.image_size), method=tf.image.ResizeMethod.BILINEAR)
mask = tf.to_float(mask > 0.5)
parsed_features['mask'] = tf.image.resize_images(
parsed_features['mask'], (self.image_size, self.image_size), method=tf.image.ResizeMethod.BILINEAR)
parsed_features['mask'] = tf.to_float(parsed_features['mask'] > 0.5)
# Index
index = tf.to_int32(parsed_features['id'], name='index')
parsed_features['id'] = tf.to_int32(parsed_features['id'], name='index')
# Return
records_dict = {'image': image, 'mask': mask, 'id': index}
if self.verbose: print_records(records_dict)
return records_dict
\ No newline at end of file
if self.verbose: print_records(parsed_features)
return parsed_features
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment