Commit 3e8e4fed authored by Amelie Royer 🐼

Adding compressed TFRecords code for VisDA

parent 2e32c272
......@@ -10,10 +10,20 @@ from matplotlib import image as mpimg
import tensorflow as tf
from .tfrecords_utils import *
from .acws import ACWSLoader
VisdaFeatures = Features([
('image', FeatureType.BYTES, FeatureLength.FIXED, (), None),
('width', FeatureType.INT, FeatureLength.FIXED, (), tf.constant(-1, dtype=tf.int64)),
('height', FeatureType.INT, FeatureLength.FIXED, (), tf.constant(-1, dtype=tf.int64)),
('class', FeatureType.INT, FeatureLength.FIXED, (), None),
class VisdaClassificationConverter(Converter):
features = VisdaFeatures
def __init__(self, data_dir):
"""Initialize the object for the VisDA dataset in `data_dir`"""
self.data_dir = data_dir
......@@ -29,83 +39,39 @@ class VisdaClassificationConverter(Converter):
print('Warning: No %s data found' % name)
def convert(self,
"""Convert the dataset in TFRecords saved in the given `tfrecords_path`"""
for name, split in
for name, split in
writer_path = '%s_%s' % (tfrecords_path, name)
writer = tf.python_io.TFRecordWriter(writer_path)
# For each dir
writer = self.init_writer(writer_path, compression_type=compression_type)
print('\nLoad', name)
for i, aux in enumerate(split):
print('\rImage %d/%d' % (i + 1, len(split)), end='')
feature = {}
# Image
image_path = aux[0]
height, width = None, None
if save_image_in_records:
img = mpimg.imread(os.path.join(self.data_dir, image_path))
if name == 'train': # synthetic
img = img * 255.
feature['image'] = bytes_feature([img.astype(np.uint8).tostring()])
feature['width'] = int64_feature([img.shape[0]])
feature['height'] = int64_feature([img.shape[1]])
height = [img.shape[0]]
width = [img.shape[1]]
img = img.astype(np.uint8).tostring()
feature['image'] = bytes_feature([base64.b64encode(image_path.encode('utf-8'))])
img = base64.b64encode(image_path.encode('utf-8'))
# Class
if len(aux) > 1:
feature['class'] = int64_feature([aux[1]])
class_id = [aux[1]] if len(aux) > 1 else None
# Write
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(self.create_example_proto([img], height, width, class_id))
print('\nWrote %s in file %s' % (name, writer_path))
print('\nWrote %s in file %s (%.2fMB)' % (
name, writer_path, os.path.getsize(writer_path) / 1e6))
class VisdaClassificationLoader():
classes_names = ['aeroplane', 'bicycle', 'bus', 'car', 'horse', 'knife',
class VisdaClassificationLoader(ACWSLoader):
classes_names = ['aeroplane', 'bicycle', 'bus', 'car', 'horse', 'knife',
'motorcycle', 'person', 'plant', 'skateboard', 'train', 'truck']
def __init__(self,
"""Init a Loader object.
`save_image_in_records` (bool): If True, the image was saved in the record, otherwise only the image path was.
`data_dir` (str): If save_image_in_records is False, append this string to the image_path saved in the record.
`resize` (int): If given, resize the image to the given size
self.save_image_in_records = save_image_in_records
self.image_dir = image_dir
self.image_size = image_resize
self.verbose = verbose
def parsing_fn(self, example_proto):
""" parsing function."""
# Basic features
features = {'image' : tf.FixedLenFeature((), tf.string),
'class': tf.FixedLenFeature((), tf.int64, default_value=tf.constant(-1, dtype=tf.int64)),
'height': tf.FixedLenFeature((), tf.int64, default_value=tf.constant(-1, dtype=tf.int64)),
'width': tf.FixedLenFeature((), tf.int64, default_value=tf.constant(-1, dtype=tf.int64))
parsed_features = tf.parse_single_example(example_proto, features)
# Load image
if self.save_image_in_records:
shape = tf.stack([parsed_features['width'], parsed_features['height'], 3], axis=0)
image = decode_raw_image(parsed_features['image'], shape, image_size=self.image_size)
filename = tf.decode_base64(parsed_features['image'])
parsed_features['image_path'] = tf.identity(filename, name='image_path')
image = decode_relative_image(filename, self.image_dir, image_size=self.image_size)
parsed_features['image'] = tf.identity(image, name='image')
# Class
parsed_features['class'] = tf.to_int32(parsed_features['class'], name='class')
# Return
del parsed_features['height']
del parsed_features['width']
if self.verbose: print_records(parsed_features)
return parsed_features
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment