Commit 7990884f authored by Amelie Royer's avatar Amelie Royer 🐼

Adding add_sample function

parent 3a3e406b
......@@ -120,6 +120,19 @@ def get_tf_dataset(path_to_tfrecords,
return in_
def get_sample(target_path, loader, compression_type=None, shuffle_buffer=1, batch_size=8):
"""Return data sample"""
with tf.Graph().as_default():
data = get_tf_dataset(
'%s_train' % target_path, loader.parsing_fn, compression_type=compression_type,
shuffle_buffer=shuffle_buffer, batch_size=batch_size)
if 'bounding_box' in data:
data['image'] = tf.image.draw_bounding_boxes(
data['image'], tf.expand_dims(data['bounding_box'], axis=1))
with tf.Session() as sess:
return sess.run(data)
def decode_raw_image(feature, shape, image_size=None):
"""Decode raw image
Args:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment