Commit 0108fc0f authored by Christoph Sommer's avatar Christoph Sommer

ws

parent 4a57586f
......@@ -29,7 +29,7 @@ tf.logging.set_verbosity(tf.logging.ERROR)
if type(tf.contrib) != type(tf): tf.contrib._warning = None
class BifCareInputConverter(object):
def __init__(self, **params):
def __init__(self, **params):
self.order = 0
self.__dict__.update(**params)
......@@ -51,8 +51,8 @@ class BifCareInputConverter(object):
else:
print(" -- Error: Pixel-type not supported. Pixel type must be 8- or 16-bit")
return
series = 0
series = 0
z_size = reader.getSizeZ()
y_size = reader.getSizeY()
x_size = reader.getSizeX()
......@@ -61,7 +61,7 @@ class BifCareInputConverter(object):
t_size = reader.getSizeT()
for t in range(t_size):
img_3d = numpy.zeros((z_size, c_size, y_size, x_size), dtype=dtype)
for z in range(z_size):
for c in range(c_size):
......@@ -71,22 +71,22 @@ class BifCareInputConverter(object):
c=c, rescale=False)
tmp_dir = pathlib.Path(self.out_dir) / "train_data" / "raw"
for c in range(c_size):
low_dir = tmp_dir / "CH_{}".format(c) / conv_token
low_dir.mkdir(parents=True, exist_ok=True)
out_tif = low_dir / "training_file_{:04d}_t{:04d}.tif".format(f_i, t)
img_3d_ch = img_3d[:, c, :, :]
if conv_scaling:
img_3d_ch = rescale(img_3d_ch, conv_scaling, preserve_range=True,
order=self.order,
img_3d_ch = rescale(img_3d_ch, conv_scaling, preserve_range=True,
order=self.order,
multichannel=False,
mode="reflect",
anti_aliasing=True)
tifffile.imsave(out_tif, img_3d_ch[:, None, :, :].astype(dtype),
tifffile.imsave(out_tif, img_3d_ch[:, None, :, :].astype(dtype),
imagej=True,
metadata={'axes': 'ZCYX'})
ir.close()
......@@ -101,9 +101,9 @@ class BifCareInputConverter(object):
print("Done")
class BifCareTrainer(object):
def __init__(self, **params):
def __init__(self, **params):
self.order = 0
self.__dict__.update(**params)
self.__dict__.update(**params)
def create_patches(self):
for ch in self.train_channels:
......@@ -125,15 +125,15 @@ class BifCareTrainer(object):
)
plt.figure(figsize=(16,4))
rand_sel = numpy.random.randint(low=0, high=len(X), size=6)
plot_some(X[rand_sel, 0],Y[rand_sel, 0],title_list=[range(6)], cmap="gray")
plt.show()
print("Done")
return
def get_training_patch_path(self):
return pathlib.Path(self.out_dir) / 'train_data' / 'patches'
......@@ -155,7 +155,7 @@ class BifCareTrainer(object):
config = Config(axes, n_channel_in, n_channel_out, train_epochs=self.train_epochs,
train_steps_per_epoch=self.train_steps_per_epoch,
train_batch_size=self.train_batch_size,
**config_args)
**config_args,)
# Training
model = CARE(config, 'CH_{}_model'.format(ch), basedir=pathlib.Path(self.out_dir) / 'models')
......@@ -174,15 +174,15 @@ class BifCareTrainer(object):
_P = model.keras_model.predict(X_val[:5])
plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
plt.suptitle('5 example validation patches\n'
'top row: input (source), '
plt.suptitle('5 example validation patches\n'
'top row: input (source), '
'middle row: target (ground truth), '
'bottom row: predicted from source');
plt.show()
plt.show()
print("-- Export model for use in Fiji...")
model.export_TF()
model.export_TF()
print("Done")
......@@ -213,8 +213,8 @@ class BifCareTrainer(object):
else:
print("Error: Pixel-type not supported. Pixel type must be 8- or 16-bit")
return
series = 0
series = 0
z_size = reader.getSizeZ()
y_size = reader.getSizeY()
x_size = reader.getSizeX()
......@@ -227,7 +227,7 @@ class BifCareTrainer(object):
if c_size != len(self.train_channels):
print(" -- Warning: Number of Channels during training and prediction do not match. Using channels {} for prediction".format(self.train_channels))
for ch in self.train_channels:
model = CARE(None, 'CH_{}_model'.format(ch), basedir=pathlib.Path(self.out_dir) / 'models')
res_image_ch = numpy.zeros(shape=(t_size, z_out_size, 1, y_out_size, x_out_size), dtype=dtype)
......@@ -237,11 +237,11 @@ class BifCareTrainer(object):
for z in range(z_size):
img_3d[z, :, :] = ir.read(series=series,
z=z,
c=ch,
c=ch,
t=t, rescale=False)
img_3d_ch_ex = rescale(img_3d, self.low_scaling, preserve_range=True,
order=self.order,
img_3d_ch_ex = rescale(img_3d, self.low_scaling, preserve_range=True,
order=self.order,
multichannel=False,
mode="reflect",
anti_aliasing=True)
......@@ -252,21 +252,21 @@ class BifCareTrainer(object):
pred = pred.clip(di.min, di.max).astype(dtype)
res_image_ch[t, :, 0, :, :] = pred
if False:
ch_t_out_fn = os.path.join(os.path.dirname(file_fn), os.path.splitext(os.path.basename(file_fn))[0] + "_care_predict_tp{:04d}_ch{}.tif".format(t, ch))
print("Saving time-point {} and channel {} to file '{}'".format(t, ch, ch_t_out_fn))
tifffile.imsave(ch_t_out_fn, pred[None,:, None, :, :], imagej=True, metadata={'axes': 'TZCYX'})
ch_out_fn = os.path.join(os.path.dirname(file_fn),
os.path.splitext(os.path.basename(file_fn))[0]
ch_out_fn = os.path.join(os.path.dirname(file_fn),
os.path.splitext(os.path.basename(file_fn))[0]
+ "_care_predict_ch{}.tif".format(ch))
print(" -- Saving channel {} CARE prediction to file '{}'".format(ch, ch_out_fn))
if keep_meta:
reso = (1 / (pixel_reso.X / self.low_scaling[2]),
reso = (1 / (pixel_reso.X / self.low_scaling[2]),
1 / (pixel_reso.Y / self.low_scaling[1]))
spacing = pixel_reso.Z / self.low_scaling[0]
unit = pixel_reso.Xunit
......@@ -274,7 +274,7 @@ class BifCareTrainer(object):
tifffile.imsave(ch_out_fn, res_image_ch, imagej=True, resolution=reso, metadata={'axes' : 'TZCYX',
'finterval': finterval,
'spacing' : spacing,
'spacing' : spacing,
'unit' : unit})
else:
tifffile.imsave(ch_out_fn, res_image_ch)
......@@ -282,7 +282,7 @@ class BifCareTrainer(object):
res_image_ch = None # should trigger gc and free the memory
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment