Commit be3da7d0 authored by Christoph Sommer's avatar Christoph Sommer

added generator for predictions also

parent d35ac169
......@@ -296,7 +296,8 @@ def train_predict(n_tiles=(1,4,4), params=params, files=None, **unet_config):
else:
datagen.from_file_list(files)
imgs = datagen.load_imgs_generator()
imgs_for_patches = datagen.load_imgs_generator()
imgs_for_predict = datagen.load_imgs_generator()
with Timer('Training and Prediction'):
......@@ -305,7 +306,8 @@ def train_predict(n_tiles=(1,4,4), params=params, files=None, **unet_config):
for c in params["train_channels"]:
print(" -- Channel {}".format(c))
img_ch = (im[..., c:c+1] for im in imgs])
img_ch = (im[..., c:c+1] for im in imgs_for_patches)
img_ch_predict = (im[..., c:c+1] for im in imgs_for_predict)
npatches = params["n_patches_per_image"] if params["n_patches_per_image"] > 1 else None
......@@ -355,7 +357,7 @@ def train_predict(n_tiles=(1,4,4), params=params, files=None, **unet_config):
plot_history(history,['loss','val_loss'])
print(" -- Predicting channel {}".format(c))
for f, im in zip(files, img_ch):
for f, im in zip(files, img_ch_predict):
print(" -- {}".format(f))
pixel_reso = get_space_time_resolution(str(f))
res_img = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment