Commit 46985f99 authored by Christoph Sommer's avatar Christoph Sommer

added probabilistic CARE (laplacian) as option

parent 1e838d74
......@@ -58,6 +58,7 @@ class GuiParams(dict):
self['train_epochs'] = 40
self['train_steps_per_epoch'] = 100
self['train_batch_size'] = 16
self['probabilistic'] = False
......@@ -279,6 +280,20 @@ def select_patch_parameter():
display(patch_parameter)
### Probabilistic
###################
def select_probabilistic(params=params):
dd_train_proba = widgets.Dropdown(options=[False, True], value=params['probabilistic'])
def on_dd_train_proba_change(change):
params['probabilistic'] = change.new
dd_train_proba.observe(on_dd_train_proba_change, 'value')
probab = widgets.VBox([widgets.HBox([widgets.Label('Probabilistic', layout={'width':'200px'}), dd_train_proba]),])
display(probab)
### Train parameter
###################
......
......@@ -160,6 +160,7 @@ class CareTrainer(object):
config = Config(axes, n_channel_in, n_channel_out, train_epochs=self.train_epochs,
train_steps_per_epoch=self.train_steps_per_epoch,
train_batch_size=self.train_batch_size,
probabilistic=self.probabilistic,
**config_args,)
# Training
model = CARE(config, 'CH_{}_model'.format(ch), basedir=pathlib.Path(self.out_dir) / 'models')
......@@ -239,7 +240,13 @@ class CareTrainer(object):
for ch in self.train_channels:
model = CARE(None, 'CH_{}_model'.format(ch), basedir=pathlib.Path(self.out_dir) / 'models')
res_image_ch = numpy.zeros(shape=(t_size, z_out_size, 1, y_out_size, x_out_size), dtype=dtype)
out_channels = 1
if self.probabilistic:
out_channels = 2
res_image_ch = numpy.zeros(shape=(t_size, z_out_size, out_channels, y_out_size, x_out_size), dtype=dtype)
print(" -- Predicting channel {}".format(ch))
for t in tqdm(range(t_size), total=t_size):
img_3d = numpy.zeros((z_size, y_size, x_size), dtype=dtype)
......@@ -255,12 +262,19 @@ class CareTrainer(object):
mode="reflect",
anti_aliasing=True)
pred = model.predict(img_3d_ch_ex, axes='ZYX', n_tiles=n_tiles)
if not self.probabilistic:
pred = model.predict(img_3d_ch_ex, axes='ZYX', n_tiles=n_tiles)
di = numpy.iinfo(dtype)
pred = pred.clip(di.min, di.max).astype(dtype)
res_image_ch[t, :, 0, :, :] = pred
else:
# probabilistic
pred = model.predict_probabilistic( img_3d_ch_ex, axes='ZYX', n_tiles=n_tiles)
di = numpy.float32
di = numpy.iinfo(dtype)
pred = pred.clip(di.min, di.max).astype(dtype)
res_image_ch[t, :, 0, :, :] = pred.mean().astype(dtype)
res_image_ch[t, :, 1, :, :] = pred.scale().astype(dtype)
res_image_ch[t, :, 0, :, :] = pred
if False:
ch_t_out_fn = os.path.join(os.path.dirname(file_fn), os.path.splitext(os.path.basename(file_fn))[0] + "_care_predict_tp{:04d}_ch{}.tif".format(t, ch))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment