Commit 6fe1bea6 authored by Christoph Sommer's avatar Christoph Sommer

added timer

parent 76c50dd3
......@@ -230,7 +230,17 @@ class BFListReader(object):
class Timer(object):
def __init__(self, name=None):
self.name = name
def __enter__(self):
self.tstart = time.time()
def __exit__(self, type, value, traceback):
if self.name:
print('[%s]' % self.name,)
print('Elapsed: %s sec.' % (time.time() - self.tstart))
......
import os
import json
import glob
import time
import numpy
import pathlib
import tifffile
......@@ -13,7 +14,7 @@ from bif_care.qt_file_dialog import gui_fname
from bif_care.qt_dir_dialog import gui_dirname
from bif_care.qt_files_dialog import gui_fnames
from bif_care.qt_filesave_dialog import gui_fsavename
from bif_care.utils import get_pixel_dimensions, get_space_time_resolution, get_file_list
from bif_care.utils import get_pixel_dimensions, get_space_time_resolution, get_file_list, Timer
from bif_care.gui import GuiParams, select_project, select_train_paramter, select_file_to_predict
class GuiParamsN2V(GuiParams):
......@@ -183,7 +184,7 @@ def select_patch_parameter():
def select_npatch_per_image():
dd_n_patch_per_img = widgets.BoundedIntText(min=-1, max=4096*2,step=1,
dd_n_patch_per_img = widgets.BoundedIntText(min=-1, max=4096*10, step=1,
value=params['n_patches_per_image'])
def on_n_patch_per_img_change(change):
......@@ -264,85 +265,87 @@ def train_predict(n_tiles=(1,4,4), params=params):
print("Loading images ...")
imgs = datagen.load_imgs()
print("Training ...")
for c in params["train_channels"]:
print(" -- Channel {}".format(c))
img_ch = [im[..., c:c+1] for im in imgs]
npatches = params["n_patches_per_image"] if params["n_patches_per_image"] > 1 else None
patches = N2V_DataGenerator().generate_patches_from_list(img_ch, num_patches_per_img=npatches, shape=params['patch_size'], augment=params['augment'])
sep = int(len(patches)*0.9)
X = patches[:sep]
X_val = patches[ sep:]
config = N2VConfig(X,
unet_kern_size=3,
train_steps_per_epoch=params["train_steps_per_epoch"],
train_epochs=params["train_epochs"],
train_loss='mse',
batch_norm=True,
train_batch_size=params["train_batch_size"],
n2v_perc_pix=params["n2v_perc_pix"],
n2v_patch_shape=params['patch_size'],
n2v_manipulator='uniform_withCP',
n2v_neighborhood_radius=params["n2v_neighborhood_radius"])
# a name used to identify the model
model_name = '{}_ch{}'.format(params['name'], c)
# the base directory in which our model will live
basedir = 'models'
# We are now creating our network model.
model = N2V(config=config, name=model_name, basedir=params["in_dir"])
history = model.train(X, X_val)
val_patch = X_val[0,..., 0]
val_patch_pred = model.predict(val_patch, axes=params["axes"])
f, ax = plt.subplots(1,2, figsize=(14,7))
if "Z" in params["axes"]:
val_patch = val_patch.max(0)
val_patch_pred = val_patch_pred.max(0)
ax[0].imshow(val_patch,cmap='gray')
ax[0].set_title('Validation Patch')
ax[1].imshow(val_patch_pred,cmap='gray')
ax[1].set_title('Validation Patch N2V')
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'])
print(" -- Predicting channel {}".format(c))
for f, im in zip(files, img_ch):
print(" -- {}".format(f))
pixel_reso = get_space_time_resolution(str(f))
res_img = []
for t in range(len(im)):
nt = n_tiles if "Z" in params["axes"] else n_tiles[1:]
pred = model.predict(im[t,..., 0], axes=params["axes"], n_tiles=nt)
if "Z" in params["axes"]:
pred = pred[:, None, ...]
res_img.append(pred)
pred = numpy.stack(res_img)
if "Z" not in params["axes"]:
pred = pred[:, None, None, ...]
reso = (1 / pixel_reso.X, 1 / pixel_reso.Y )
spacing = pixel_reso.Z
unit = pixel_reso.Xunit
finterval = pixel_reso.T
tifffile.imsave("{}_n2v_pred_ch{}.tiff".format(str(f)[:-4], c), pred, imagej=True, resolution=reso, metadata={'axes': 'TZCYX',
'finterval': finterval,
'spacing' : spacing,
'unit' : unit})
with Timer('Training and Prediction'):
print("Training ...")
for c in params["train_channels"]:
print(" -- Channel {}".format(c))
img_ch = [im[..., c:c+1] for im in imgs]
npatches = params["n_patches_per_image"] if params["n_patches_per_image"] > 1 else None
patches = N2V_DataGenerator().generate_patches_from_list(img_ch, num_patches_per_img=npatches, shape=params['patch_size'], augment=params['augment'])
sep = int(len(patches)*0.9)
X = patches[:sep]
X_val = patches[ sep:]
config = N2VConfig(X,
unet_kern_size=3,
train_steps_per_epoch=params["train_steps_per_epoch"],
train_epochs=params["train_epochs"],
train_loss='mse',
batch_norm=True,
train_batch_size=params["train_batch_size"],
n2v_perc_pix=params["n2v_perc_pix"],
n2v_patch_shape=params['patch_size'],
n2v_manipulator='uniform_withCP',
n2v_neighborhood_radius=params["n2v_neighborhood_radius"])
# a name used to identify the model
model_name = '{}_ch{}'.format(params['name'], c)
# the base directory in which our model will live
basedir = 'models'
# We are now creating our network model.
model = N2V(config=config, name=model_name, basedir=params["in_dir"])
history = model.train(X, X_val)
val_patch = X_val[0,..., 0]
val_patch_pred = model.predict(val_patch, axes=params["axes"])
f, ax = plt.subplots(1,2, figsize=(14,7))
if "Z" in params["axes"]:
val_patch = val_patch.max(0)
val_patch_pred = val_patch_pred.max(0)
ax[0].imshow(val_patch,cmap='gray')
ax[0].set_title('Validation Patch')
ax[1].imshow(val_patch_pred,cmap='gray')
ax[1].set_title('Validation Patch N2V')
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'])
print(" -- Predicting channel {}".format(c))
for f, im in zip(files, img_ch):
print(" -- {}".format(f))
pixel_reso = get_space_time_resolution(str(f))
res_img = []
for t in range(len(im)):
nt = n_tiles if "Z" in params["axes"] else n_tiles[1:]
pred = model.predict(im[t,..., 0], axes=params["axes"], n_tiles=nt)
if "Z" in params["axes"]:
pred = pred[:, None, ...]
res_img.append(pred)
pred = numpy.stack(res_img)
if "Z" not in params["axes"]:
pred = pred[:, None, None, ...]
reso = (1 / pixel_reso.X, 1 / pixel_reso.Y )
spacing = pixel_reso.Z
unit = pixel_reso.Xunit
finterval = pixel_reso.T
tifffile.imsave("{}_n2v_pred_ch{}.tiff".format(str(f)[:-4], c), pred, imagej=True, resolution=reso, metadata={'axes': 'TZCYX',
'finterval': finterval,
'spacing' : spacing,
'unit' : unit})
def predict(files, n_tiles=(1,4,4), params=params):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment