Commit 2e32c272 authored by Amelie Royer's avatar Amelie Royer 🐼

adding split options to get_sample

parent ecbad29c
......@@ -120,11 +120,11 @@ def get_tf_dataset(path_to_tfrecords,
return in_
def get_sample(target_path, loader, compression_type=None, shuffle_buffer=1, batch_size=8):
"""Return data sample"""
def get_sample(target_path, loader, split='train', compression_type=None, shuffle_buffer=1, batch_size=8):
"""Creates a dataset and returns a data sample (batch) from it"""
with tf.Graph().as_default():
data = get_tf_dataset(
'%s_train' % target_path, loader.parsing_fn, compression_type=compression_type,
'%s_%s' % (target_path, split), loader.parsing_fn, compression_type=compression_type,
shuffle_buffer=shuffle_buffer, batch_size=batch_size)
if 'bounding_box' in data:
data['image'] = tf.image.draw_bounding_boxes(
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment