Commit 9b999c8b authored by Amelie Royer's avatar Amelie Royer 🐼

Adding default value for FixedLenFeature

parent b3d8ccd2
......@@ -18,14 +18,14 @@ def unpickle(file_path):
"""Define features to be stored in the TFRecords"""
CIFAR10Features = Features([('image', FeatureType.BYTES, FeatureLength.FIXED, (),),
('class', FeatureType.INT, FeatureLength.FIXED, (),),
('class_str', FeatureType.BYTES, FeatureLength.FIXED, (),)])
CIFAR100Features = Features([('image', FeatureType.BYTES, FeatureLength.FIXED, (),),
('class', FeatureType.INT, FeatureLength.FIXED, (),),
('coarse_class', FeatureType.INT, FeatureLength.FIXED, (),),
('coarse_class_str', FeatureType.BYTES, FeatureLength.FIXED, (),)])
CIFAR10Features = Features([('image', FeatureType.BYTES, FeatureLength.FIXED, (), None),
('class', FeatureType.INT, FeatureLength.FIXED, (), None),
('class_str', FeatureType.BYTES, FeatureLength.FIXED, (), None)])
CIFAR100Features = Features([('image', FeatureType.BYTES, FeatureLength.FIXED, (), None),
('class', FeatureType.INT, FeatureLength.FIXED, (), None),
('coarse_class', FeatureType.INT, FeatureLength.FIXED, (), None),
('coarse_class_str', FeatureType.BYTES, FeatureLength.FIXED, (), None)])
class CIFAR10Converter(Converter):
......
......@@ -15,9 +15,9 @@ def read_integer(bytel):
return int('0x' + ''.join('{:02x}'.format(x) for x in bytel), 0)
"""Define features to be stored in the TFRecords"""
MNISTFeatures = Features([('class', FeatureType.INT, FeatureLength.FIXED, (),),
('image', FeatureType.BYTES, FeatureLength.FIXED, (),),
('id', FeatureType.INT, FeatureLength.FIXED, (),)])
MNISTFeatures = Features([('class', FeatureType.INT, FeatureLength.FIXED, (), None),
('image', FeatureType.BYTES, FeatureLength.FIXED, (), None),
('id', FeatureType.INT, FeatureLength.FIXED, (), None)])
class MNISTConverter(Converter):
......
......@@ -29,15 +29,15 @@ class FeatureLength(Enum):
class Features:
def __init__(self, feature_list):
"""A feature_list is a list of tuple containing each feature name, type, whether it is of
fixed or variable length, and its shape.
fixed or variable length, its shape, and optional default value.
Where type is one of 'int64', 'string' or 'float'."""
# Features dictionnary (reading)
self.features_read = {name: (tf.FixedLenFeature(shape, feature_type.value[0])
self.features_read = {name: (tf.FixedLenFeature(shape, feature_type.value[0], default_value=default)
if feature_length == FeatureLength.FIXED else
tf.VarLenFeature(feature_type.value[0]))
for name, feature_type, feature_length, shape in feature_list}
for name, feature_type, feature_length, shape, default in feature_list}
# Featured dictionnary (writing)
self.features_write = [(name, feature_type.value[1]) for name, feature_type, _, _ in feature_list]
self.features_write = [(name, feature_type.value[1]) for name, feature_type, _, _, _ in feature_list]
### Base converter class
......@@ -61,8 +61,8 @@ class Converter(ABC):
return tf.python_io.TFRecordWriter(writer_path, options=writer_options)
def create_example_proto(self, *args):
"""Create a TFRecords example protobuffer from the given arguments"""
feature = {name: fn(x) for x, (name, fn) in zip(args, self.features.features_write)}
"""Create a TFRecords example protobuffer from the given arguments. Ignore `None` values"""
feature = {name: fn(x) for x, (name, fn) in zip(args, self.features.features_write) if x is not None}
example = tf.train.Example(features=tf.train.Features(feature=feature))
return example.SerializeToString()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment