Commit 9306cd9f authored by Amelie Royer 🐼

Add compressed TFRecords options and factor features code

parent 58c69c56
from abc import ABC, abstractmethod
from enum import Enum
import tensorflow as tf
### Convenience function for writing Feature in TFRecords
def bytes_feature(value):
    """Wrap `value` in a `tf.train.Feature` holding a `BytesList`.

    `value` is forwarded unchanged as the `value` argument of
    `tf.train.BytesList`.
    """
    bytes_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=bytes_list)
def float_feature(value):
    """Wrap `value` in a `tf.train.Feature` holding a `FloatList`.

    `value` is forwarded unchanged as the `value` argument of
    `tf.train.FloatList`.
    """
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


# `floats_feature` appears to be the previous name of this helper (only its
# header survived the rename); keep it importable as an alias.
floats_feature = float_feature
def int64_feature(value):
    """Wrap `value` in a `tf.train.Feature` holding an `Int64List`.

    `value` is forwarded unchanged as the `value` argument of
    `tf.train.Int64List`.
    """
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)
### Convenience function for creating a basic tf.data.Dataset
def get_tf_dataset(path_to_tfrecords, parsing_fn, shuffle_buffer=1, batch_size=8):
### Base Features Class
class FeatureType(Enum):
    """Kinds of feature that can be stored in a record.

    Every member's value is a pair:
      * index 0: the TensorFlow dtype used when declaring the feature for parsing,
      * index 1: the helper that builds the matching `tf.train.Feature` for writing.
    """

    # Integer features
    INT = (tf.int64, int64_feature)
    # Floating point features
    FLOAT = (tf.float32, float_feature)
    # Raw bytes / string features
    BYTES = (tf.string, bytes_feature)
class FeatureLength(Enum):
    """Whether a feature is declared with a fixed or a variable length."""

    # Fixed-length feature
    FIXED = 0
    # Variable-length feature
    VAR = 1
class Features:
    """Parsing and serialization specifications for a set of record features."""

    def __init__(self, feature_list):
        """Build the reading and writing specifications from `feature_list`.

        Args:
            feature_list: iterable of `(name, feature_type, feature_length, shape)`
                tuples. `feature_type` is a `FeatureType` member and
                `feature_length` a `FeatureLength` member. `shape` is passed to
                `tf.FixedLenFeature` for fixed-length features and is unused
                for variable-length ones.
        """
        # The specification is traversed twice below; materialize it first so
        # that one-shot iterables (e.g. generators) are not exhausted by the
        # first pass, which would silently leave `features_write` empty.
        feature_list = list(feature_list)
        # Features dictionary (reading): name -> parsing specification
        self.features_read = {
            name: (tf.FixedLenFeature(shape, feature_type.value[0])
                   if feature_length == FeatureLength.FIXED else
                   tf.VarLenFeature(feature_type.value[0]))
            for name, feature_type, feature_length, shape in feature_list}
        # Features list (writing): ordered (name, feature building function) pairs
        self.features_write = [(name, feature_type.value[1])
                               for name, feature_type, _, _ in feature_list]
### Base converter class
class Converter(ABC):
    """Abstract base class for objects converting a dataset to TFRecords.

    Subclasses implement `__init__` and `convert`; the `features` property
    raises `NotImplementedError` in this base class.
    """

    @property
    def features(self):
        """Object whose `features_write` attribute lists the (name, function) pairs to serialize."""
        raise NotImplementedError

    @abstractmethod
    def __init__(self, data_dir):
        """Initialize the converter from `data_dir`."""

    def init_writer(self, writer_path, compression_type=None):
        """Return a `TFRecordWriter` writing to `writer_path`.

        `compression_type` is one of None (no writer options), 'gzip' or 'zlib'.
        """
        assert compression_type in [None, 'gzip', 'zlib']
        if compression_type == 'gzip':
            codec = tf.python_io.TFRecordCompressionType.GZIP
        elif compression_type == 'zlib':
            codec = tf.python_io.TFRecordCompressionType.ZLIB
        else:
            # No compression requested: the writer is created without options
            return tf.python_io.TFRecordWriter(writer_path, options=None)
        compressed = tf.python_io.TFRecordOptions(codec)
        return tf.python_io.TFRecordWriter(writer_path, options=compressed)

    def create_example_proto(self, *args):
        """Serialize `args` as a `tf.train.Example` and return the resulting string.

        Positional arguments are paired, in order, with the (name, function)
        entries of `self.features.features_write`.
        """
        encoded = {}
        for item, (key, to_feature) in zip(args, self.features.features_write):
            encoded[key] = to_feature(item)
        proto_features = tf.train.Features(feature=encoded)
        return tf.train.Example(features=proto_features).SerializeToString()

    @abstractmethod
    def convert(self, tfrecords_path, compression_type=None):
        """Convert the dataset to TFRecords format"""
### Base loader class
class Loader(ABC):
    """Abstract base class for objects parsing records back into tensors.

    Subclasses implement `__init__` and `parsing_fn`; the `features` property
    raises `NotImplementedError` in this base class.
    """

    @property
    def features(self):
        """Object whose `features_read` attribute holds the parsing specification."""
        raise NotImplementedError

    @abstractmethod
    def __init__(self):
        """Initialize the loader."""

    def raw_parsing_fn(self, example_proto):
        """Parse one serialized example using `self.features.features_read`."""
        return tf.parse_single_example(serialized=example_proto,
                                       features=self.features.features_read)

    @abstractmethod
    def parsing_fn(self, example_proto):
        """Parse `example_proto`; to be implemented by subclasses."""
### Other convenience functions for parsing
def get_tf_dataset(path_to_tfrecords,
                   parsing_fn,
                   compression_type=None,
                   compression_buffer=0,
                   shuffle_buffer=1,
                   batch_size=8):
    """Create a basic one-shot tensorflow Dataset object from a TFRecords.

    Args:
        path_to_tfrecords: Path to the TFrecords
        parsing_fn: parsing function to apply to every element (load Examples)
        compression_type: Compression of the records; one of None, 'gzip' or 'zlib'
        compression_buffer: Passed as `buffer_size` to `tf.data.TFRecordDataset`
        shuffle_buffer: Shuffle buffer size to randomize the dataset
        batch_size: Batch size
    """
    print('[dataset] batch_size = %d, shuffle buffer = %d' % (
        batch_size, shuffle_buffer))
    assert compression_type in [None, 'gzip', 'zlib']
    # TFRecordDataset expects '', 'GZIP' or 'ZLIB'. The default None has no
    # `.upper()`, so map it explicitly to the "no compression" value.
    tf_compression = '' if compression_type is None else compression_type.upper()
    data = tf.data.TFRecordDataset(path_to_tfrecords,
                                   compression_type=tf_compression,
                                   buffer_size=compression_buffer)
    data = data.shuffle(shuffle_buffer)
    data = data.map(parsing_fn)
    data = data.batch(batch_size)
......@@ -67,16 +151,10 @@ def decode_relative_image(filename, image_dir, image_size=None):
return image
def print_records(records_dict):
    """Print the entries of `records_dict`, sorted by key, below a colored header (verbose mode)."""
    print('\u001b[36mOutputs:\u001b[0m')
    entry_format = ' \u001b[46m%s\u001b[0m: %s'
    rendered = [entry_format % (name, records_dict[name])
                for name in sorted(records_dict.keys())]
    # A single print of the joined lines: an empty dictionary still emits one blank line
    print('\n'.join(rendered))
def make_square_bounding_box(bounding_box, mode='max'):
"""Given a bounding box [ymin, xmin, ymax, xmax] in [0., 1.], compute a square bounding box centered around it,
whose side is equal to the maximum or minimum side
"""Given a bounding box [ymin, xmin, ymax, xmax] in [0., 1.],
compute a square bounding box centered around it,
whose side is equal to the maximum or minimum side
"""
assert mode in ['max', 'min']
width = bounding_box[2] - bounding_box[0]
......@@ -86,24 +164,10 @@ def make_square_bounding_box(bounding_box, mode='max'):
offset_y = (size - height) / 2.
offset = tf.stack([- offset_x, - offset_y, offset_x, offset_y], axis=0)
return bounding_box + offset
### Base converter class
class Converter(ABC):
    """Abstract base class for dataset converters; subclasses implement `__init__` and `convert`."""

    @abstractmethod
    def __init__(self, data_dir):
        """Initialize the converter from `data_dir`."""

    @abstractmethod
    def convert(self, tfrecords_path):
        """Run the conversion for `tfrecords_path`; to be implemented by subclasses."""
### Base loader class
class Loader(ABC):
    """Abstract base class for record loaders; subclasses implement `__init__` and `parsing_fn`."""

    @abstractmethod
    def __init__(self):
        """Initialize the loader."""

    @abstractmethod
    def parsing_fn(self, example_proto):
        """Parse `example_proto`; to be implemented by subclasses."""
\ No newline at end of file
def print_records(records_dict):
    """Pretty-print the entries of `records_dict`, sorted by key, below a colored header."""
    print('\u001b[36mContents:\u001b[0m')
    entry_format = ' \u001b[46m%s\u001b[0m: %s'
    rendered = [entry_format % (name, records_dict[name])
                for name in sorted(records_dict.keys())]
    # A single print of the joined lines: an empty dictionary still emits one blank line
    print('\n'.join(rendered))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment