Commit 9311953a authored by Amelie Royer 🐼

Adding compressed TFRecords code for PACS

parent d0b0ddf2
......@@ -12,8 +12,15 @@ import tensorflow as tf
from .tfrecords_utils import *
"""Define features to be stored in the TFRecords"""
# Schema shared by the PACS converter and loader. Each entry starts with the
# feature name, its FeatureType and its FeatureLength; the trailing `()` and
# `None` are presumably the feature shape and default value -- TODO confirm
# against `tfrecords_utils.Features`.
PACSFeatures = Features([
    # Raw uint8 image bytes when images are stored in the record, otherwise
    # the base64-encoded image path relative to the dataset directory.
    ('image', FeatureType.BYTES, FeatureLength.FIXED, (), None),
    # Integer index of the style the image belongs to.
    ('class_style', FeatureType.INT, FeatureLength.FIXED, (), None),
    # Integer index of the content class the image belongs to.
    ('class_content', FeatureType.INT, FeatureLength.FIXED, (), None),
])
class PACSConverter(Converter):
features = PACSFeatures
def __init__(self, data_dir):
"""Initialize the object for the PACS dataset in `data_dir`"""
self.data_dir = data_dir
......@@ -29,24 +36,23 @@ class PACSConverter(Converter):
content_dir = os.path.join(self.data_dir, style, content)
if not os.path.exists(content_dir):
print('Warning: no directory found for content %s in style %s' % (content, style))
continue
continue
self.raw_data[i][j] = sorted([os.path.join(style, content, x)
for x in os.listdir(content_dir) if x.rsplit('.', 1)[1] in ['jpg', 'png']])
self.train_data = None
self.val_data = None
self.test_data = None
self.test_data = None
self.has_generated_split = False
def generate_split(self, split_path, train=0.7, val=0.1, test=0.2):
"""Generate a train, val, test split uniformly over all classes
"""Generate a train, val, test split uniformly over all classes
and export the result as a text file in split_path (0 = train, 1 = val, 2 = test"""
assert train + val + test == 1.0
self.train_data = [[[] for _ in range(7)] for _ in range(4)]
self.val_data = [[[] for _ in range(7)] for _ in range(4)]
self.test_data = [[[] for _ in range(7)] for _ in range(4)]
# Uniform split
for style, d in enumerate(self.raw_data):
for style, d in enumerate(self.raw_data):
for content, image_paths in enumerate(d):
n = len(image_paths)
train_fence = int(n * train)
......@@ -59,17 +65,21 @@ class PACSConverter(Converter):
self.test_data[style][content] = paths[indices[val_fence:]]
# Export to file
with open(split_path, 'w') as f:
f.write('\n'.join('%s 0' % image for style_list in self.train_data for content_list in style_list
f.write('\n'.join('%s 0' % image for style_list in self.train_data for content_list in style_list
for image in content_list))
f.write('\n'.join('%s 1' % image for style_list in self.val_data for content_list in style_list
f.write('\n'.join('%s 1' % image for style_list in self.val_data for content_list in style_list
for image in content_list))
f.write('\n'.join('%s 2' % image for style_list in self.test_data for content_list in style_list
f.write('\n'.join('%s 2' % image for style_list in self.test_data for content_list in style_list
for image in content_list))
self.has_generated_split = True
print('Splits saved in', split_path)
def convert(self, tfrecords_path, save_image_in_records=False, separate_styles=False):
def convert(self,
tfrecords_path,
save_image_in_records=False,
separate_styles=False,
compression_type=None):
"""Convert the dataset in TFRecords saved in the given `tfrecords_path`"""
# If no split has been generated, then convert the full data
if self.has_generated_split:
......@@ -77,59 +87,57 @@ class PACSConverter(Converter):
else:
data = zip(['full'], [self.raw_data])
style_names = ['art_painting', 'cartoon', 'photo', 'sketch']
for name, split in data:
if split is None:
for name, split in data:
if split is None:
continue
if not separate_styles:
writer_path = '%s_%s' % (tfrecords_path, name)
writer = tf.python_io.TFRecordWriter(writer_path)
# For each dir
writer = self.init_writer(writer_path, compression_type=compression_type)
print('\nLoad', name)
for s, style_list in enumerate(split):
if separate_styles:
writer_path = '%s_%s_%s' % (tfrecords_path, style_names[s], name)
writer = tf.python_io.TFRecordWriter(writer_path)
writer = self.init_writer(writer_path, compression_type=compression_type)
for c, content_list in enumerate(style_list):
print('\rstyle %d/%d - content %d/%d' % (s + 1, len(split), c + 1, len(style_list)), end='')
for image_path in content_list:
feature = {}
# Image
if save_image_in_records:
img = mpimg.imread(os.path.join(self.data_dir, image_path))
if style_names[s] == 'sketch':
img = img * 255.
img = img[:, :, :3]
feature['image'] = bytes_feature([img.astype(np.uint8).tostring()])
img = img.astype(np.uint8).tostring()
else:
feature['image'] = bytes_feature([base64.b64encode(image_path.encode('utf-8'))])
# Class
feature['class_style'] = int64_feature([s])
feature['class_content'] = int64_feature([c])
img = base64.b64encode(image_path.encode('utf-8'))
# Write
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
writer.write(self.create_example_proto([img], [s], [c]))
if separate_styles:
writer.close()
print('\nWrote %s for style %s in file %s' % (name, style_names[s], writer_path))
print('\nWrote %s for style %s in file %s (%.2fMB)' % (
name, style_names[s], writer_path, os.path.getsize(writer_path) / 1e6))
if not separate_styles:
writer.close()
print('\nWrote %s in file %s' % (name, writer_path))
print('\nWrote %s in file %s (%.2fMB)' % (
name, writer_path, os.path.getsize(writer_path) / 1e6))
print()
class PACSLoader():
class PACSLoader(Loader):
features = PACSFeatures
style_names = ['art_painting', 'cartoon', 'photo', 'sketch']
content_names = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']
def __init__(self,
save_image_in_records=False,
save_image_in_records=False,
image_dir='',
image_size=None,
verbose=False):
"""Init a Loader object.
Args:
`save_image_in_records` (bool): If True, the image was saved in the record, otherwise only the image path was.
`data_dir` (str): If save_image_in_records is False, append this string to the image_path saved in the record.
......@@ -139,17 +147,13 @@ class PACSLoader():
self.image_dir = image_dir
self.image_size = image_size
self.verbose = verbose
def parsing_fn(self, example_proto):
"""tf.data.Dataset parsing function."""
# Basic features
features = {'image' : tf.FixedLenFeature((), tf.string),
'class_content': tf.FixedLenFeature((), tf.int64),
'class_style': tf.FixedLenFeature((), tf.int64)
}
parsed_features = tf.parse_single_example(example_proto, features)
# Load image
if self.save_image_in_records:
# Parse
parsed_features = self.raw_parsing_fn(example_proto)
# Reshape
if self.save_image_in_records:
image = decode_raw_image(parsed_features['image'], (227, 227, 3), image_size=self.image_size)
else:
filename = tf.decode_base64(parsed_features['image'])
......@@ -161,4 +165,4 @@ class PACSLoader():
parsed_features['class_style'] = tf.to_int32(parsed_features['class_style'])
# Return
if self.verbose: print_records(parsed_features)
return parsed_features
\ No newline at end of file
return parsed_features
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment