Commit 4efd92d6 authored by Christoph Sommer's avatar Christoph Sommer

reading of img with bioformats / added n_patch_per_img

parent 31d32b99
......@@ -11,10 +11,10 @@ from IPython.display import display
from bif_care.qt_file_dialog import gui_fname
from bif_care.qt_dir_dialog import gui_dirname
from bif_care.qt_filesave_dialog import gui_fsavename
from bif_care.qt_files_dialog import gui_fnames
from bif_care.utils import get_pixel_dimensions, get_space_time_resolution, get_file_list
from bif_care.qt_filesave_dialog import gui_fsavename
from bif_care.gui import GuiParams, select_project, select_train_paramter
from bif_care.utils import get_pixel_dimensions, get_space_time_resolution, get_file_list
class GuiParamsN2V(GuiParams):
def initialize(self):
......@@ -24,6 +24,7 @@ class GuiParamsN2V(GuiParams):
self["glob"] = ""
self["axes"] = "ZYX"
self['patch_size'] = []
self['n_patches_per_image'] = -1
self["train_channels"] = [0]
self['train_epochs'] = 40
self['train_steps_per_epoch'] = 100
......@@ -31,11 +32,12 @@ class GuiParamsN2V(GuiParams):
self['n2v_perc_pix'] = 0.016
self['n2v_patch_shape'] = []
self['n2v_neighborhood_radius'] = 5
params = GuiParamsN2V()
params.initialize()
# params.load("H:/projects/024_care_bif/n2v/test/bif_n2v.json")
params.load("C:/Users/csommer/Desktop/bif_n2v2.json")
select_project = partial(select_project, default_name='./bif_n2v.json', params=params)
select_train_paramter = partial(select_train_paramter, params=params)
......@@ -148,7 +150,8 @@ def select_patch_parameter():
patch_size_select = []
patch_options = [8, 16, 32, 64, 128, 256]
params['patch_size'] = [64]*len(list(params["axes"]))
if len(params['patch_size']) == 0:
params['patch_size'] = [64]*len(list(params["axes"]))
for j, a in enumerate(list(params["axes"])):
wi = widgets.Dropdown(options=list(map(str, patch_options)),
......@@ -168,6 +171,18 @@ def select_patch_parameter():
display(widgets.HBox([widgets.Label('Patch size', layout={'width':'100px'}), patch_size_select]))
def select_npatch_per_image():
dd_n_patch_per_img = widgets.BoundedIntText(min=-1, max=4096*2,step=1,
value=params['n_patches_per_image'])
def on_n_patch_per_img_change(change):
params['n_patches_per_image'] = change.new
dd_n_patch_per_img.observe(on_n_patch_per_img_change, 'value')
display(widgets.HBox([widgets.Label('#Patches per image', layout={'width':'100px'}), dd_n_patch_per_img]))
def select_n2v_parameter():
### N2V neighbor radius
###################
......@@ -214,26 +229,32 @@ def train_predict(n_tiles=(1,4,4), params=params):
from n2v.utils.n2v_utils import manipulate_val_data
from n2v.internals.N2V_DataGenerator import N2V_DataGenerator
from matplotlib import pyplot as plt
from bif_care.utils import BFListReader
np = numpy
datagen = N2V_DataGenerator()
datagen = BFListReader(params["in_dir"], params["glob"])
files = datagen.img_fns
files = glob.glob(os.path.join(params["in_dir"], params["glob"]))
imgs = datagen.load_imgs(files, dims=params["axes"])
imgs = datagen.load_imgs()
patches = datagen.generate_patches_from_list(imgs, shape=params['patch_size'])
for c in params["train_channels"]:
print("Training channel {}".format(c))
sep = int(len(patches)*0.9)
X = patches[:sep]
X_val = patches[ sep:]
img_ch = [im[..., c:c+1] for im in imgs]
npatches = params["n_patches_per_image"] if params["n_patches_per_image"] > 1 else None
print("npatches", npatches)
patches = N2V_DataGenerator().generate_patches_from_list(img_ch, num_patches_per_img=npatches, shape=params['patch_size'])
for c in params["train_channels"]:
print("Training channel {}".format(c))
sep = int(len(patches)*0.9)
X = patches[:sep]
X_val = patches[ sep:]
print("X.shape", X.shape)
print(img_ch[0].shape, img_ch[0].shape)
config = N2VConfig(X[...,c:c+1],
config = N2VConfig(X,
unet_kern_size=3,
train_steps_per_epoch=params["train_steps_per_epoch"],
train_epochs=params["train_epochs"],
......@@ -253,12 +274,15 @@ def train_predict(n_tiles=(1,4,4), params=params):
# We are now creating our network model.
model = N2V(config=config, name=model_name, basedir=params["in_dir"])
history = model.train(X[...,c:c+1], X_val[...,c:c+1])
history = model.train(X, X_val)
val_patch = X_val[0,...,c]
val_patch = X_val[0,..., 0]
print("val_patch.shape", val_patch.shape)
val_patch_pred = model.predict(val_patch,axes=params["axes"])
print("val_patch_pred.shape", val_patch_pred.shape)
# Let's look at two patches.
f, ax = plt.subplots(1,2, figsize=(14,7))
......@@ -272,7 +296,24 @@ def train_predict(n_tiles=(1,4,4), params=params):
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'])
for f, im in zip(files, imgs):
for f, im in zip(files, img_ch):
print("Predicting {}".format(f))
pred = model.predict(im[0, ..., c], axes=params["axes"], n_tiles=n_tiles)
tifffile.imsave("{}_n2v_pred_ch{}.tif".format(f[:-4], c), pred, imagej=True)
pixel_reso = get_space_time_resolution(str(f))
res_img = []
for t in range(len(im)):
pred = model.predict(im[t,..., 0], axes=params["axes"], n_tiles=n_tiles)
pred = pred[:, None, ...]
res_img.append(pred)
pred = numpy.stack(res_img)
reso = (1 / pixel_reso.X,
1 / pixel_reso.Y )
spacing = pixel_reso.Z
unit = pixel_reso.Xunit
finterval = pixel_reso.T
tifffile.imsave("{}_n2v_pred_ch{}.tiff".format(str(f)[:-4], c), pred, imagej=True, resolution=reso, metadata={'axes': 'TZCYX',
'finterval': finterval,
'spacing' : spacing,
'unit' : unit})
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment