Commit 72222181 authored by Christoph Sommer's avatar Christoph Sommer

working for 2d use cases

parent e1eb399e
......@@ -188,8 +188,6 @@ class BFListReader(object):
dim = self.get_axes(fn)
dims.append(dim)
print(dims)
assert dims.count(dims[0]) == len(dims), "Dimensions of image files do not match"
self.axes = dims[0]
......
......@@ -32,12 +32,13 @@ class GuiParamsN2V(GuiParams):
self['n2v_perc_pix'] = 0.016
self['n2v_patch_shape'] = []
self['n2v_neighborhood_radius'] = 5
self['augment'] = False
params = GuiParamsN2V()
params.initialize()
params.load("C:/Users/csommer/Desktop/bif_n2v2.json")
params.load("H:/projects/024_care_bif/n2v/test_CTYX/bif_n2v.json")
select_project = partial(select_project, default_name='./bif_n2v.json', params=params)
select_train_paramter = partial(select_train_paramter, params=params)
......@@ -128,8 +129,8 @@ def select_input(params=params):
##################
def select_channel():
available_channels = list(range(get_pixel_dimensions(get_file_list(params["in_dir"], params["glob"])[0]).c))
available_channels_str = list(map(str, available_channels))
channel_str = list(map(str, params["train_channels"]))
ms_channel = widgets.widgets.SelectMultiple(
......@@ -142,7 +143,7 @@ def select_channel():
ms_channel.observe(on_channel_change, 'value')
ms_channel.value = channel_str
display(widgets.HBox([widgets.Label("Channels", layout={'width':'100px'}), ms_channel]))
display(widgets.HBox([widgets.Label("Channels", layout={'width':'200px'}), ms_channel]))
def select_patch_parameter():
### Path size select
......@@ -169,7 +170,7 @@ def select_patch_parameter():
patch_size_select = widgets.HBox(patch_size_select)
display(widgets.HBox([widgets.Label('Patch size', layout={'width':'100px'}), patch_size_select]))
display(widgets.HBox([widgets.Label('Patch size', layout={'width':'200px'}), patch_size_select]))
def select_npatch_per_image():
......@@ -181,7 +182,23 @@ def select_npatch_per_image():
dd_n_patch_per_img.observe(on_n_patch_per_img_change, 'value')
display(widgets.HBox([widgets.Label('#Patches per image', layout={'width':'100px'}), dd_n_patch_per_img]))
display(widgets.HBox([widgets.Label('#Patches per image (-1: all)', layout={'width':'200px'}), dd_n_patch_per_img]))
def select_augment():
dd_augment = widgets.Dropdown(
options=[('8 rotation/flips', True), ('No augment', False)],
value=False,
)
def on_dd_augment_change(change):
params['augment'] = change.new
dd_augment.observe(on_dd_augment_change, 'value')
display(widgets.HBox([widgets.Label('Augment', layout={'width':'200px'}), dd_augment]))
def select_n2v_parameter():
### N2V neighbor radius
......@@ -214,9 +231,9 @@ def select_n2v_parameter():
### Combine
##############
n2v_parameter = widgets.VBox([
widgets.HBox([widgets.Label('Neighborhood radius', layout={'width':'100px'}), int_n2v_neighborhood_radius]),
widgets.HBox([widgets.Label('Perc. pixel manipulation', layout={'width':'100px'}), float_n2v_perc_pix]),
widgets.HBox([widgets.Label('Model name', layout={'width':'100px'}), text_n2v_name]),
widgets.HBox([widgets.Label('Neighborhood radius', layout={'width':'200px'}), int_n2v_neighborhood_radius]),
widgets.HBox([widgets.Label('Perc. pixel manipulation', layout={'width':'200px'}), float_n2v_perc_pix]),
widgets.HBox([widgets.Label('Model name', layout={'width':'200px'}), text_n2v_name]),
])
display(n2v_parameter)
......@@ -236,34 +253,32 @@ def train_predict(n_tiles=(1,4,4), params=params):
print("Loading images ...")
imgs = datagen.load_imgs()
print("Training ...")
for c in params["train_channels"]:
print("Training channel {}".format(c))
print(" -- Channel {}".format(c))
img_ch = [im[..., c:c+1] for im in imgs]
npatches = params["n_patches_per_image"] if params["n_patches_per_image"] > 1 else None
patches = N2V_DataGenerator().generate_patches_from_list(img_ch, num_patches_per_img=npatches, shape=params['patch_size'], augment=False)
patches = N2V_DataGenerator().generate_patches_from_list(img_ch, num_patches_per_img=npatches, shape=params['patch_size'], augment=params['augment'])
sep = int(len(patches)*0.9)
X = patches[:sep]
X_val = patches[ sep:]
print("X.shape", X.shape)
print("img_ch.shape", [iii.shape for iii in img_ch])
config = N2VConfig(X,
unet_kern_size=3,
train_steps_per_epoch=params["train_steps_per_epoch"],
train_epochs=params["train_epochs"],
train_loss='mse',
batch_norm=True,
train_batch_size=params["train_batch_size"],
n2v_perc_pix=params["n2v_perc_pix"],
n2v_patch_shape=params['patch_size'],
n2v_manipulator='uniform_withCP',
n2v_neighborhood_radius=params["n2v_neighborhood_radius"])
unet_kern_size=3,
train_steps_per_epoch=params["train_steps_per_epoch"],
train_epochs=params["train_epochs"],
train_loss='mse',
batch_norm=True,
train_batch_size=params["train_batch_size"],
n2v_perc_pix=params["n2v_perc_pix"],
n2v_patch_shape=params['patch_size'],
n2v_manipulator='uniform_withCP',
n2v_neighborhood_radius=params["n2v_neighborhood_radius"])
# a name used to identify the model
......@@ -275,28 +290,41 @@ def train_predict(n_tiles=(1,4,4), params=params):
history = model.train(X, X_val)
val_patch = X_val[0,..., 0]
val_patch_pred = model.predict(val_patch, axes=params["axes"])
f, ax = plt.subplots(1,2, figsize=(14,7))
ax[0].imshow(val_patch[0, ...],cmap='gray')
if "Z" in params["axes"]:
val_patch = val_patch.max(0)
val_patch_pred = val_patch_pred.max(0)
ax[0].imshow(val_patch,cmap='gray')
ax[0].set_title('Validation Patch')
ax[1].imshow(val_patch_pred[0, ...],cmap='gray')
ax[1].imshow(val_patch_pred,cmap='gray')
ax[1].set_title('Validation Patch N2V')
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'])
print("Predicting channel {}".format(c))
print(" -- Predicting channel {}".format(c))
for f, im in zip(files, img_ch):
print(" -- {}".format(f))
pixel_reso = get_space_time_resolution(str(f))
res_img = []
for t in range(len(im)):
pred = model.predict(im[t,..., 0], axes=params["axes"], n_tiles=n_tiles)
pred = pred[:, None, ...]
nt = n_tiles if "Z" in params["axes"] else n_tiles[1:]
print("im .shape", im.shape)
pred = model.predict(im[t,..., 0], axes=params["axes"], n_tiles=nt)
print("pred.shape", pred.shape)
if "Z" in params["axes"]:
pred = pred[:, None, ...]
res_img.append(pred)
pred = numpy.stack(res_img)
if "Z" not in params["axes"]:
pred = pred[:, None, None, ...]
reso = (1 / pixel_reso.X, 1 / pixel_reso.Y )
spacing = pixel_reso.Z
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment