Commit 365ea9b7 authored by Christoph Sommer's avatar Christoph Sommer

,ompr clean up

parent 4c702419
......@@ -217,9 +217,7 @@ def select_n2v_parameter():
widgets.HBox([widgets.Label('Neighborhood radius', layout={'width':'100px'}), int_n2v_neighborhood_radius]),
widgets.HBox([widgets.Label('Perc. pixel manipulation', layout={'width':'100px'}), float_n2v_perc_pix]),
widgets.HBox([widgets.Label('Model name', layout={'width':'100px'}), text_n2v_name]),
])
])
display(n2v_parameter)
......@@ -235,6 +233,7 @@ def train_predict(n_tiles=(1,4,4), params=params):
datagen = BFListReader(params["in_dir"], params["glob"])
files = datagen.img_fns
print("Loading images ...")
imgs = datagen.load_imgs()
for c in params["train_channels"]:
......@@ -243,15 +242,15 @@ def train_predict(n_tiles=(1,4,4), params=params):
img_ch = [im[..., c:c+1] for im in imgs]
npatches = params["n_patches_per_image"] if params["n_patches_per_image"] > 1 else None
print("npatches", npatches)
patches = N2V_DataGenerator().generate_patches_from_list(img_ch, num_patches_per_img=npatches, shape=params['patch_size'])
patches = N2V_DataGenerator().generate_patches_from_list(img_ch, num_patches_per_img=npatches, shape=params['patch_size'], augment=False)
sep = int(len(patches)*0.9)
X = patches[:sep]
X_val = patches[ sep:]
print("X.shape", X.shape)
print(img_ch[0].shape, img_ch[0].shape)
print("img_ch.shape", [iii.shape for iii in img_ch])
config = N2VConfig(X,
......@@ -277,27 +276,19 @@ def train_predict(n_tiles=(1,4,4), params=params):
history = model.train(X, X_val)
val_patch = X_val[0,..., 0]
print("val_patch.shape", val_patch.shape)
val_patch_pred = model.predict(val_patch,axes=params["axes"])
print("val_patch_pred.shape", val_patch_pred.shape)
# Let's look at two patches.
val_patch_pred = model.predict(val_patch, axes=params["axes"])
f, ax = plt.subplots(1,2, figsize=(14,7))
ax[0].imshow(val_patch[0, ...],cmap='gray')
ax[0].set_title('Validation Patch')
ax[1].imshow(val_patch_pred[0, ...],cmap='gray')
ax[1].set_title('Validation Patch N2V')
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'])
print("Predicting channel {}".format(c))
for f, im in zip(files, img_ch):
print("Predicting {}".format(f))
print(" -- {}".format(f))
pixel_reso = get_space_time_resolution(str(f))
res_img = []
for t in range(len(im)):
......@@ -307,8 +298,7 @@ def train_predict(n_tiles=(1,4,4), params=params):
pred = numpy.stack(res_img)
reso = (1 / pixel_reso.X,
1 / pixel_reso.Y )
reso = (1 / pixel_reso.X, 1 / pixel_reso.Y )
spacing = pixel_reso.Z
unit = pixel_reso.Xunit
finterval = pixel_reso.T
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment