Commit 56a6774c authored by Amelie Royer 🐼

Adding compressed TFRecords code for TinyImageNet

parent 8bc0cc8e
......@@ -12,8 +12,19 @@ import tensorflow as tf
from .tfrecords_utils import *
"""Define features to be stored in the TFRecords"""
# Record schema for TinyImageNet, referenced as the `features` class attribute
# of both TinyImageNetConverter and TinyImageNetLoader.
# NOTE(review): each entry looks like (name, value type, length kind, shape,
# default value) -- `Features` comes from tfrecords_utils, confirm the tuple
# layout there.
TinyImageNetFeatures = Features([
# Image payload as a single bytes value; the only entry whose last field is None.
('image', FeatureType.BYTES, FeatureLength.FIXED, (), None),
# Scalar int64 image dimensions; last field is -1.
('width', FeatureType.INT, FeatureLength.FIXED, (), tf.constant(-1, dtype=tf.int64)),
('height', FeatureType.INT, FeatureLength.FIXED, (), tf.constant(-1, dtype=tf.int64)),
# Four floats; last field is [0, 0, 1, 1].
# NOTE(review): presumably normalized box coordinates with the whole image as
# the fallback -- confirm the coordinate order against the converter.
('bounding_box', FeatureType.FLOAT, FeatureLength.FIXED, (4,), tf.constant([0., 0., 1., 1.])),
# Integer class id (last field -1) and its bytes label (last field '').
('class', FeatureType.INT, FeatureLength.FIXED, (), -1),
('class_str', FeatureType.BYTES, FeatureLength.FIXED, (), '')
])
class TinyImageNetConverter(Converter):
features = TinyImageNetFeatures
def __init__(self, data_dir):
"""Initialize the object for the TinyImageNet dataset in `data_dir`"""
self.data_dir = data_dir
......@@ -54,54 +65,59 @@ class TinyImageNetConverter(Converter):
else:
print('Warning: Missing test data')
def convert(self,
            tfrecords_path,
            compression_type=None,
            save_image_in_records=False):
    """Convert the dataset in TFRecords saved in the given `tfrecords_path`.

    One output file named `<tfrecords_path>_<name>` is written for every
    `(name, data)` pair found in `self.data`.

    Args:
        `tfrecords_path` (str): Prefix of the output file paths.
        `compression_type`: Forwarded unchanged to `self.init_writer`.
        `save_image_in_records` (bool): If True, the raw uint8 pixel buffer of the image is stored
            in the record; otherwise only the base64-encoded image path is stored.
    """
    for name, data in self.data:
        writer_path = '%s_%s' % (tfrecords_path, name)
        writer = self.init_writer(writer_path, compression_type=compression_type)
        print('\nLoad', name)
        try:
            for i, item in enumerate(data):
                print('\rImage %d/%d' % (i + 1, len(data)), end='')
                image_path = item[0]
                has_bbox = len(item) > 2

                # Image shape: the file is only read from disk when the pixels are
                # stored or when the bounding box has to be normalized.
                height, width = None, None
                pixels = None
                if save_image_in_records or has_bbox:
                    pixels = mpimg.imread(os.path.join(self.data_dir, image_path))
                    height = [pixels.shape[0]]
                    width = [pixels.shape[1]]

                # Image: raw bytes or encoded path. `tobytes` replaces the
                # deprecated `tostring` alias (same output).
                if save_image_in_records:
                    image = pixels.astype(np.uint8).tobytes()
                else:
                    image = base64.b64encode(image_path.encode('utf-8'))

                # Class
                class_id, class_name, bbox = None, None, None
                if len(item) > 1:
                    synset = item[1]
                    class_id = [self.synsets_to_ids[synset]]
                    class_name = [base64.b64encode(
                        self.synsets_to_labels[synset].encode('utf-8'))]

                # Normalized bounding box
                if has_bbox:
                    box = item[2]
                    bbox = np.array([box[1] / width[0], box[0] / height[0],
                                     box[3] / width[0], box[2] / height[0]],
                                    dtype=np.float32)
                    bbox = bbox.flatten()

                # Write
                # NOTE(review): `height` is passed before `width` while the feature
                # list declares 'width' first -- confirm the positional mapping in
                # `create_example_proto`.
                writer.write(self.create_example_proto(
                    [image], height, width, bbox, class_id, class_name))
        finally:
            # Always release the file handle, even if an item fails to convert.
            writer.close()
        print('\nWrote %s in file %s (%.2fMB)' % (
            name, writer_path, os.path.getsize(writer_path) / 1e6))
        print()
class TinyImageNetLoader():
class TinyImageNetLoader(Loader):
features = TinyImageNetFeatures
def __init__(self,
save_image_in_records=False,
save_image_in_records=False,
image_dir='',
image_size=None,
verbose=False):
"""Init a Loader object.
"""Init a Loader object.
Args:
`save_image_in_records` (bool): If True, the image was saved in the record, otherwise only the image path was.
`image_dir` (str): If save_image_in_records is False, append this string to the image_path saved in the record.
......@@ -111,21 +127,13 @@ class TinyImageNetLoader():
self.image_dir = image_dir
self.image_size = image_size
self.verbose = verbose
def parsing_fn(self, example_proto):
"""tf.data.Dataset parsing function."""
# Basic features
features = {'image' : tf.FixedLenFeature((), tf.string),
'class': tf.FixedLenFeature((), tf.int64, default_value=-1),
'class_str': tf.FixedLenFeature((), tf.string, default_value=''),
'bounding_box': tf.FixedLenFeature((4,), tf.float32,
default_value=tf.constant([0., 0., 1., 1.], dtype=tf.float32)),
'height': tf.FixedLenFeature((), tf.int64, default_value=tf.constant(-1, dtype=tf.int64)),
'width': tf.FixedLenFeature((), tf.int64, default_value=tf.constant(-1, dtype=tf.int64))
}
parsed_features = tf.parse_single_example(example_proto, features)
# Load image
if self.save_image_in_records:
# Parse
parsed_features = self.raw_parsing_fn(example_proto)
# Reshape
if self.save_image_in_records:
shape = tf.stack([parsed_features['width'], parsed_features['height'], 3], axis=0)
image = decode_raw_image(parsed_features['image'], shape, image_size=self.image_size)
else:
......@@ -139,4 +147,4 @@ class TinyImageNetLoader():
del parsed_features['height']
del parsed_features['width']
if self.verbose: print_records(parsed_features)
return parsed_features
\ No newline at end of file
return parsed_features
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment