Commit 23fe0131 authored by Christoph Sommer's avatar Christoph Sommer

advanced options as keyword args

parent b8d1e1d4
......@@ -38,7 +38,7 @@ class GuiParamsN2V(GuiParams):
params = GuiParamsN2V()
params.initialize()
#params.load("H:/projects/024_care_bif/n2v/test_TCZYX/bif_n2v.json")
#params.loadn_("J:/_BIF/RH/Noisy MO/bif_n2v.json")
select_project = partial(select_project, default_name='./bif_n2v.json', params=params)
select_train_paramter = partial(select_train_paramter, params=params)
......@@ -248,7 +248,25 @@ def select_n2v_parameter():
display(n2v_parameter)
def train_predict(n_tiles=(1,4,4), params=params):
def train_predict(n_tiles=(1,4,4), params=params, **unet_config):
"""
These advanced options can be set by keyword arguments:
n_tiles : tuple(int)
Number of tiles to tile the image into, if it is too large for memory.
unet_residual : bool
Parameter `residual` of :func:`n2v_old.nets.common_unet`. Default: ``n_channel_in == n_channel_out``
unet_n_depth : int
Parameter `n_depth` of :func:`n2v_old.nets.common_unet`. Default: ``2``
unet_kern_size : int
Parameter `kern_size` of :func:`n2v_old.nets.common_unet`. Default: ``5 if n_dim==2 else 3``
unet_n_first : int
Parameter `n_first` of :func:`n2v_old.nets.common_unet`. Default: ``32``
batch_norm : bool
Activate batch norm
unet_last_activation : str
Parameter `last_activation` of :func:`n2v_old.nets.common_unet`. Default: ``linear``
"""
from n2v.models import N2VConfig, N2V
from csbdeep.utils import plot_history
from n2v.utils.n2v_utils import manipulate_val_data
......@@ -265,6 +283,8 @@ def train_predict(n_tiles=(1,4,4), params=params):
print("Loading images ...")
imgs = datagen.load_imgs()
with Timer('Training and Prediction'):
print("Training ...")
......@@ -282,16 +302,14 @@ def train_predict(n_tiles=(1,4,4), params=params):
X_val = patches[ sep:]
config = N2VConfig(X,
unet_kern_size=3,
train_steps_per_epoch=params["train_steps_per_epoch"],
train_epochs=params["train_epochs"],
train_loss='mse',
batch_norm=True,
train_batch_size=params["train_batch_size"],
n2v_perc_pix=params["n2v_perc_pix"],
n2v_patch_shape=params['patch_size'],
n2v_manipulator='uniform_withCP',
n2v_neighborhood_radius=params["n2v_neighborhood_radius"])
n2v_neighborhood_radius=params["n2v_neighborhood_radius"], **unet_config)
# a name used to identify the model
......
......@@ -96,7 +96,8 @@
"source": [
"## 3. Denoise (train and predict)\n",
"---\n",
"* Images will be tiled according to the **n_tiles** (z,y,x) parameter. (usefull for big images)"
"* Images will be tiled according to the **n_tiles** (z,y,x) parameter. (usefull for big images)\n",
"* For advanced training parameters type and execute *bif_n2v.train_predict?* in an empty code cell"
]
},
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment