Commit 5887c5c6 authored by Amelie Royer's avatar Amelie Royer

introduce finalize_graph call

parent ca0787cb
......@@ -276,19 +276,21 @@ with tf.name_scope("sampling"):
tf.GLOBAL.pop('ema')
# tensorboard dummies for @generate_samples
# Reconstruction
# Reconstruction (bpd is computed outside of the graph so we only feed a placeholder for Tensorboard)
bpd_rec = tf.get_variable("bpd_rec", initializer=tf.constant(0, dtype=tf.float64))
bpd_rec_ = tf.placeholder(tf.float64, shape=[], name="bpd_rec")
update_bpd_rec = bpd_rec.assign(bpd_rec_)
sum_bpd_rec = tf.summary.scalar("rec_bits_per_dimension", bpd_rec)
rec_error = tf.get_variable("rec_error", initializer=tf.constant(0, dtype=tf.float64))
sum_rec_error = tf.summary.scalar("rec_error", rec_error)
# Generation (images) Computed outside of the graph
rec_imgs = tf.placeholder(shape=(1, None, None, C_IN), dtype=tf.float32)
rec_pic_imgs_summary = tf.summary.image("rec_pic", rec_imgs, max_outputs=1)
rec_embd_imgs_summary = tf.summary.image("rec_embd", rec_imgs, max_outputs=1)
# Generation
gen_error = tf.get_variable("gen_error", initializer=tf.constant(0, dtype=tf.float64))
sum_gen_error = tf.summary.scalar("gen_error", gen_error)
gen_imgs = tf.placeholder(shape=(1, None, None, C_IN), dtype=tf.float32)
gen_pic_imgs_summary = tf.summary.image("gen_pic", gen_imgs, max_outputs=1)
gen_embd_imgs_summary = tf.summary.image("gen_embd", gen_imgs, max_outputs=1)
gen_error_gray = tf.get_variable("gen_error_gray", initializer=tf.constant(0, dtype=tf.float64))
sum_gen_error_gray = tf.summary.scalar("gen_error_gray", gen_error_gray)
############### Generate samples
......@@ -306,7 +308,6 @@ def generate_samples(images, sess, summary_writer, reconstruct=False, from_embed
resolution (int): chroma sampling resolution
"""
global WIDTH, HEIGHT, C_IN, args, samplers_from_pic, samplers_from_embedding
global rec_error, rec_error_gray, bpd_gen, gen_error, gen_error_gray
gray_images = color_to_gray(convert_color(images, colorspace=args.color, normalized_out=True), colorspace=args.color) #BxWxHx1
samplers = samplers_from_embedding if from_embedding else samplers_from_pic
......@@ -351,17 +352,16 @@ def generate_samples(images, sess, summary_writer, reconstruct=False, from_embed
# Summary
imgs = tile_image(x_gen)
if reconstruct:
summary_str = sess.run(tf.summary.image("rec%s_%d" % ("embd" if from_embedding else "pic", id), imgs, max_outputs=1))
summary_str = sess.run(rec_embd_imgs_summary if from_embedding else rec_pic_imgs_summary, feed_dict={rec_imgs: imgs})
else:
summary_str = sess.run(tf.summary.image("gen%s_%d%s" % ("embd" if from_embedding else "pic", id, '_mode' if sample_mode else ''), imgs, max_outputs=1))
summary_str = sess.run(gen_embd_imgs_summary if from_embedding else gen_pic_imgs_summary, feed_dict={gen_imgs: imgs})
summary_writer.add_summary(summary_str, id)
############### Main
inits = tf.global_variables_initializer()
train_summary_op = tf.summary.merge([sum_bpd, sum_bpd_embedding])
rec_summary_op = tf.summary.merge([sum_bpd_rec, sum_rec_error])
gen_summary_op = tf.summary.merge([sum_gen_error, sum_gen_error_gray])
rec_summary_op = tf.summary.merge([sum_bpd_rec])
with tf.Session() as sess:
### Init saver and summary objects
......@@ -379,6 +379,7 @@ with tf.Session() as sess:
summary_writer = tf.summary.FileWriter(log_dir, graph=sess.graph)
saver = tf.train.Saver()
### Restore model
if args.model:
print("Loading model from", "%s%s%s" % (bcolors.YELLOW, args.model, bcolors.RES), "...")
saver.restore(sess, args.model)
......@@ -395,10 +396,15 @@ with tf.Session() as sess:
if args.n_generations > 0 and args.mode in ['train', 'test']:
summary_str = sess.run(tf.summary.image("original", tile_image(images_test_gen), max_outputs=1))
summary_writer.add_summary(summary_str)
### Training mode
lr = args.learning_rate
if args.mode == 'train':
## Finalize the graph
tf.get_default_graph().finalize()
print('\x1b[37mFinal graph size: %.2f MB\x1b[0m' % (tf.get_default_graph().as_graph_def().ByteSize() / 10e6))
## Train
try:
for i in range(1, args.epochs + 1):
### Train
......@@ -436,7 +442,7 @@ with tf.Session() as sess:
avg_bpd_val += bpd
b_val += 1
print("\r(val) Batch: %d/%d" % (b_val, len(images_test) // (args.batch_size * args.nr_gpus)), ' ' * 35, end = '')
_ = sess.run(bpd_rec.assign(avg_bpd_val / b_val))
_ = sess.run(update_bpd_rec, feed_dict={bpd_rec_: avg_bpd_val / b_val})
### Sampling experiments
if args.gen_epochs > 0 and not i % args.gen_epochs:
......@@ -458,7 +464,6 @@ with tf.Session() as sess:
"" if b_val <= 0 else "(avg_bpd_val: {}{:.3f}{})".format(bcolors.YELLOW, avg_bpd_val / b_val, bcolors.RES),
' ' * 30)
summary_writer.add_summary(sess.run(rec_summary_op), i)
summary_writer.add_summary(sess.run(gen_summary_op), i)
if args.save_epochs > 0 and not i % args.save_epochs:
saver.save(sess, os.path.join(log_dir, "model.ckpt"))
......@@ -566,4 +571,4 @@ with tf.Session() as sess:
x_gen = np.concatenate([gray_images, x_gen], axis=3)
x_gen = convert_color(x_gen, colorspace=args.color, normalized_in=True, normalized_out=False, reverse=True)
# save
imsave('demo_generations.jpg', tile_image(x_gen)[0])
\ No newline at end of file
imsave('demo_generations.jpg', tile_image(x_gen)[0])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment