Commit 5887c5c6 authored by Amelie Royer's avatar Amelie Royer

introduce finalize_graph call

parent ca0787cb
...@@ -276,19 +276,21 @@ with tf.name_scope("sampling"): ...@@ -276,19 +276,21 @@ with tf.name_scope("sampling"):
tf.GLOBAL.pop('ema') tf.GLOBAL.pop('ema')
# tensorboard dummies for @generate_samples # tensorboard dummies for @generate_samples
# Reconstruction # Reconstruction (bpd is computed outside of the graph so we only feed a placeholder for Tensorboard)
bpd_rec = tf.get_variable("bpd_rec", initializer=tf.constant(0, dtype=tf.float64)) bpd_rec = tf.get_variable("bpd_rec", initializer=tf.constant(0, dtype=tf.float64))
bpd_rec_ = tf.placeholder(tf.float64, shape=[], name="bpd_rec")
update_bpd_rec = bpd_rec.assign(bpd_rec_)
sum_bpd_rec = tf.summary.scalar("rec_bits_per_dimension", bpd_rec) sum_bpd_rec = tf.summary.scalar("rec_bits_per_dimension", bpd_rec)
rec_error = tf.get_variable("rec_error", initializer=tf.constant(0, dtype=tf.float64)) # Generation (images) Computed outside of the graph
sum_rec_error = tf.summary.scalar("rec_error", rec_error) rec_imgs = tf.placeholder(shape=(1, None, None, C_IN), dtype=tf.float32)
rec_pic_imgs_summary = tf.summary.image("rec_pic", rec_imgs, max_outputs=1)
rec_embd_imgs_summary = tf.summary.image("rec_embd", rec_imgs, max_outputs=1)
# Generation gen_imgs = tf.placeholder(shape=(1, None, None, C_IN), dtype=tf.float32)
gen_error = tf.get_variable("gen_error", initializer=tf.constant(0, dtype=tf.float64)) gen_pic_imgs_summary = tf.summary.image("gen_pic", gen_imgs, max_outputs=1)
sum_gen_error = tf.summary.scalar("gen_error", gen_error) gen_embd_imgs_summary = tf.summary.image("gen_embd", gen_imgs, max_outputs=1)
gen_error_gray = tf.get_variable("gen_error_gray", initializer=tf.constant(0, dtype=tf.float64))
sum_gen_error_gray = tf.summary.scalar("gen_error_gray", gen_error_gray)
############### Generate samples ############### Generate samples
...@@ -306,7 +308,6 @@ def generate_samples(images, sess, summary_writer, reconstruct=False, from_embed ...@@ -306,7 +308,6 @@ def generate_samples(images, sess, summary_writer, reconstruct=False, from_embed
resolution (int): chroma sampling resolution resolution (int): chroma sampling resolution
""" """
global WIDTH, HEIGHT, C_IN, args, samplers_from_pic, samplers_from_embedding global WIDTH, HEIGHT, C_IN, args, samplers_from_pic, samplers_from_embedding
global rec_error, rec_error_gray, bpd_gen, gen_error, gen_error_gray
gray_images = color_to_gray(convert_color(images, colorspace=args.color, normalized_out=True), colorspace=args.color) #BxWxHx1 gray_images = color_to_gray(convert_color(images, colorspace=args.color, normalized_out=True), colorspace=args.color) #BxWxHx1
samplers = samplers_from_embedding if from_embedding else samplers_from_pic samplers = samplers_from_embedding if from_embedding else samplers_from_pic
...@@ -351,17 +352,16 @@ def generate_samples(images, sess, summary_writer, reconstruct=False, from_embed ...@@ -351,17 +352,16 @@ def generate_samples(images, sess, summary_writer, reconstruct=False, from_embed
# Summary # Summary
imgs = tile_image(x_gen) imgs = tile_image(x_gen)
if reconstruct: if reconstruct:
summary_str = sess.run(tf.summary.image("rec%s_%d" % ("embd" if from_embedding else "pic", id), imgs, max_outputs=1)) summary_str = sess.run(rec_embd_imgs_summary if from_embedding else rec_pic_imgs_summary, feed_dict={rec_imgs: imgs})
else: else:
summary_str = sess.run(tf.summary.image("gen%s_%d%s" % ("embd" if from_embedding else "pic", id, '_mode' if sample_mode else ''), imgs, max_outputs=1)) summary_str = sess.run(gen_embd_imgs_summary if from_embedding else gen_pic_imgs_summary, feed_dict={gen_imgs: imgs})
summary_writer.add_summary(summary_str, id) summary_writer.add_summary(summary_str, id)
############### Main ############### Main
inits = tf.global_variables_initializer() inits = tf.global_variables_initializer()
train_summary_op = tf.summary.merge([sum_bpd, sum_bpd_embedding]) train_summary_op = tf.summary.merge([sum_bpd, sum_bpd_embedding])
rec_summary_op = tf.summary.merge([sum_bpd_rec, sum_rec_error]) rec_summary_op = tf.summary.merge([sum_bpd_rec])
gen_summary_op = tf.summary.merge([sum_gen_error, sum_gen_error_gray])
with tf.Session() as sess: with tf.Session() as sess:
### Init saver and summary objects ### Init saver and summary objects
...@@ -379,6 +379,7 @@ with tf.Session() as sess: ...@@ -379,6 +379,7 @@ with tf.Session() as sess:
summary_writer = tf.summary.FileWriter(log_dir, graph=sess.graph) summary_writer = tf.summary.FileWriter(log_dir, graph=sess.graph)
saver = tf.train.Saver() saver = tf.train.Saver()
### Restore model
if args.model: if args.model:
print("Loading model from", "%s%s%s" % (bcolors.YELLOW, args.model, bcolors.RES), "...") print("Loading model from", "%s%s%s" % (bcolors.YELLOW, args.model, bcolors.RES), "...")
saver.restore(sess, args.model) saver.restore(sess, args.model)
...@@ -395,10 +396,15 @@ with tf.Session() as sess: ...@@ -395,10 +396,15 @@ with tf.Session() as sess:
if args.n_generations > 0 and args.mode in ['train', 'test']: if args.n_generations > 0 and args.mode in ['train', 'test']:
summary_str = sess.run(tf.summary.image("original", tile_image(images_test_gen), max_outputs=1)) summary_str = sess.run(tf.summary.image("original", tile_image(images_test_gen), max_outputs=1))
summary_writer.add_summary(summary_str) summary_writer.add_summary(summary_str)
### Training mode ### Training mode
lr = args.learning_rate lr = args.learning_rate
if args.mode == 'train': if args.mode == 'train':
## Finalize the graph
tf.get_default_graph().finalize()
print('\x1b[37mFinal graph size: %.2f MB\x1b[0m' % (tf.get_default_graph().as_graph_def().ByteSize() / 10e6))
## Train
try: try:
for i in range(1, args.epochs + 1): for i in range(1, args.epochs + 1):
### Train ### Train
...@@ -436,7 +442,7 @@ with tf.Session() as sess: ...@@ -436,7 +442,7 @@ with tf.Session() as sess:
avg_bpd_val += bpd avg_bpd_val += bpd
b_val += 1 b_val += 1
print("\r(val) Batch: %d/%d" % (b_val, len(images_test) // (args.batch_size * args.nr_gpus)), ' ' * 35, end = '') print("\r(val) Batch: %d/%d" % (b_val, len(images_test) // (args.batch_size * args.nr_gpus)), ' ' * 35, end = '')
_ = sess.run(bpd_rec.assign(avg_bpd_val / b_val)) _ = sess.run(update_bpd_rec, feed_dict={bpd_rec_: avg_bpd_val / b_val})
### Sampling experiments ### Sampling experiments
if args.gen_epochs > 0 and not i % args.gen_epochs: if args.gen_epochs > 0 and not i % args.gen_epochs:
...@@ -458,7 +464,6 @@ with tf.Session() as sess: ...@@ -458,7 +464,6 @@ with tf.Session() as sess:
"" if b_val <= 0 else "(avg_bpd_val: {}{:.3f}{})".format(bcolors.YELLOW, avg_bpd_val / b_val, bcolors.RES), "" if b_val <= 0 else "(avg_bpd_val: {}{:.3f}{})".format(bcolors.YELLOW, avg_bpd_val / b_val, bcolors.RES),
' ' * 30) ' ' * 30)
summary_writer.add_summary(sess.run(rec_summary_op), i) summary_writer.add_summary(sess.run(rec_summary_op), i)
summary_writer.add_summary(sess.run(gen_summary_op), i)
if args.save_epochs > 0 and not i % args.save_epochs: if args.save_epochs > 0 and not i % args.save_epochs:
saver.save(sess, os.path.join(log_dir, "model.ckpt")) saver.save(sess, os.path.join(log_dir, "model.ckpt"))
...@@ -566,4 +571,4 @@ with tf.Session() as sess: ...@@ -566,4 +571,4 @@ with tf.Session() as sess:
x_gen = np.concatenate([gray_images, x_gen], axis=3) x_gen = np.concatenate([gray_images, x_gen], axis=3)
x_gen = convert_color(x_gen, colorspace=args.color, normalized_in=True, normalized_out=False, reverse=True) x_gen = convert_color(x_gen, colorspace=args.color, normalized_in=True, normalized_out=False, reverse=True)
# save # save
imsave('demo_generations.jpg', tile_image(x_gen)[0]) imsave('demo_generations.jpg', tile_image(x_gen)[0])
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment