Commit a020f8b8 authored by Amelie Royer's avatar Amelie Royer

Adding CIFAR-10

parent 497153f4
......@@ -32,10 +32,11 @@ The loader simply builds a proper parsing function to extract data from the TFRe
### Table of Contents
| Dataset | Link | Example | TFRecords contents |
| ------- | ---- | ------ | --- |
| :-----: | :--: | :-----: | :----------------: |
| ACwS | [Apparel Classification with Style](http://www.vision.ee.ethz.ch/~lbossard/projects/accv12/index.html) | ![acws_thumb](images/acws.png) | image, class |
| CelebA | [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) | ![celeba_thumb](images/celeba.png) | image, bounding-box, attributes, landmarks |
| CartoonSet | [CartoonSet](https://google.github.io/cartoonset/) | ![cartoonset_thumb](images/cartoonset.png) | image, bounding-box, attributes |
| CIFAR-10 | [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) | ![cifar10_thumb](images/cifar10.png) | image, class, class-name |
| Fashion MNIST| [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) | ![fashion_mnist_thumb](images/fashion_mnist.png) | image, class, index|
| MNIST | [MNIST](http://yann.lecun.com/exdb/mnist/) | ![mnist_thumb](images/mnist.png) | image, digit-class, index |
| MNIST-M | [MNIST-M](http://yaroslav.ganin.net/) | ![mnistm_thumb](images/mnistm.png) | image, digit-class, index |
......
......@@ -4,7 +4,6 @@ from __future__ import print_function
# http://www.vision.ee.ethz.ch/~lbossard/projects/accv12/index.html #
#####################################################################
import base64
import csv
import os
import numpy as np
from matplotlib import image as mpimg
......
......@@ -4,7 +4,6 @@ from __future__ import print_function
# http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html #
####################################################
import base64
import csv
import os
import numpy as np
from matplotlib import image as mpimg
......
from __future__ import print_function
###############################################
# CIFAR-10 #
# https://www.cs.toronto.edu/~kriz/cifar.html #
###############################################
import base64
import pickle
import os
import numpy as np
from matplotlib import image as mpimg
import tensorflow as tf
from .tfrecords_utils import *
def unpickle(file_path):
with open(file_path, 'rb') as f:
return pickle.load(f, encoding='bytes')
class CIFAR10Converter(Converter):
def __init__(self, data_dir):
"""Initialize the object for the CIFAR-10 dataset in `data_dir`"""
self.data_dir = data_dir
self.data = []
# Train
train_batches = []
for i in range(5):
b = os.path.join(self.data_dir, 'data_batch_%d' % (i + 1))
if not os.path.isfile(b):
print('Warning: Missing train batch', i + 1)
else:
train_batches.append(b)
if len(train_batches):
self.data.append(('train', train_batches))
# Test
test_batch = os.path.join(self.data_dir, 'test_batch')
if not os.path.isfile(test_batch):
print('Warning: Missing test batch')
else:
self.data.append(('test', [test_batch]))
# Labels
self.label_names = unpickle(os.path.join(self.data_dir, 'batches.meta'))[b'label_names']
def convert(self, tfrecords_path, save_image_in_records=False):
"""Convert the dataset in TFRecords saved in the given `tfrecords_path`"""
for name, data in self.data:
writer_path = '%s_%s' % (tfrecords_path, name)
writer = tf.python_io.TFRecordWriter(writer_path)
print('\nLoad', name)
for i, item in enumerate(data):
print('\rBatch %d/%d' % (i + 1, len(data)), end='')
d = unpickle(item)
for img, label in zip(d[b'data'], d[b'labels']):
class_name = self.label_names[label]
img = np.transpose(np.reshape(img, (3, 32, 32)), (1, 2, 0))
example = tf.train.Example(features=tf.train.Features(
feature={'image': bytes_feature([img.astype(np.uint8).tostring(order='C')]),
'class': int64_feature([label]),
'class_str': bytes_feature([base64.b64encode(class_name)])}))
writer.write(example.SerializeToString())
writer.close()
print('\nWrote %s in file %s' % (name, writer_path))
print()
class CIFAR10Loader():
def __init__(self,
image_size=None,
verbose=False):
"""Init a Loader object."""
self.image_size = image_size
self.verbose = verbose
def parsing_fn(self, example_proto):
"""tf.data.Dataset parsing function."""
# Basic features
features = {'image' : tf.FixedLenFeature((), tf.string),
'class': tf.FixedLenFeature((), tf.int64),
'class_str': tf.FixedLenFeature((), tf.string),
}
parsed_features = tf.parse_single_example(example_proto, features)
image = decode_raw_image(parsed_features['image'], (32, 32, 3), image_size=self.image_size)
parsed_features['image'] = tf.identity(image, name='image')
parsed_features['class'] = tf.to_int32(parsed_features['class'])
parsed_features['class_str'] = tf.decode_base64(parsed_features['class_str'])
# Return
if self.verbose: print_records(parsed_features)
return parsed_features
\ No newline at end of file
......@@ -4,13 +4,13 @@ from __future__ import print_function
# http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017 #
######################################################
import base64
import csv
import os
import numpy as np
from matplotlib import image as mpimg
from .tfrecords_utils import *
import tensorflow as tf
from .tfrecords_utils import *
class PACSConverter(Converter):
......
......@@ -4,7 +4,6 @@ from __future__ import print_function
# https://tiny-imagenet.herokuapp.com/ #
########################################
import base64
import csv
import os
import numpy as np
from matplotlib import image as mpimg
......@@ -117,9 +116,10 @@ class TinyImageNetLoader():
"""tf.data.Dataset parsing function."""
# Basic features
features = {'image' : tf.FixedLenFeature((), tf.string),
'class': tf.FixedLenFeature((), tf.int64),
'class_str': tf.FixedLenFeature((), tf.string),
'bounding_box': tf.FixedLenFeature((4,), tf.float32),
'class': tf.FixedLenFeature((), tf.int64, default_value=-1),
'class_str': tf.FixedLenFeature((), tf.string, default_value=''),
'bounding_box': tf.FixedLenFeature((4,), tf.float32,
default_value=tf.constant([0., 0., 1., 1.], dtype=tf.float32)),
'height': tf.FixedLenFeature((), tf.int64, default_value=tf.constant(-1, dtype=tf.int64)),
'width': tf.FixedLenFeature((), tf.int64, default_value=tf.constant(-1, dtype=tf.int64))
}
......
......@@ -4,13 +4,13 @@ from __future__ import print_function
# http://ai.bu.edu/visda-2017/ #
################################
import base64
import csv
import os
import numpy as np
from matplotlib import image as mpimg
from .tfrecords_utils import *
import tensorflow as tf
from .tfrecords_utils import *
class VisdaClassificationConverter(Converter):
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment