Commit 92b5e472 authored by Christoph Sommer's avatar Christoph Sommer

added timer to training

parent 9d905625
......@@ -20,7 +20,7 @@ from csbdeep.data import RawData, create_patches
from csbdeep.utils import axes_dict, plot_some, plot_history
from .utils import JVM, get_file_list, get_pixel_dimensions, \
get_upscale_factors, get_space_time_resolution
get_upscale_factors, get_space_time_resolution, Timer
import warnings
warnings.filterwarnings("ignore")
......@@ -143,45 +143,47 @@ class BifCareTrainer(object):
if channels is None:
channels = self.train_channels
for ch in channels:
print("-- Training channel {}...".format(ch))
(X,Y), (X_val,Y_val), axes = load_training_data(self.get_training_patch_path() / 'CH_{}_training_patches.npz'.format(ch), validation_split=0.1, verbose=False)
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
config = Config(axes, n_channel_in, n_channel_out, train_epochs=self.train_epochs,
train_steps_per_epoch=self.train_steps_per_epoch,
train_batch_size=self.train_batch_size,
**config_args)
# Training
model = CARE(config, 'CH_{}_model'.format(ch), basedir=pathlib.Path(self.out_dir) / 'models')
# Show learning curve and example validation results
try:
history = model.train(X,Y, validation_data=(X_val,Y_val))
except tf.errors.ResourceExhaustedError:
print("ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?")
return
with Timer('Training'):
for ch in channels:
print("-- Training channel {}...".format(ch))
(X,Y), (X_val,Y_val), axes = load_training_data(self.get_training_patch_path() / 'CH_{}_training_patches.npz'.format(ch), validation_split=0.1, verbose=False)
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
config = Config(axes, n_channel_in, n_channel_out, train_epochs=self.train_epochs,
train_steps_per_epoch=self.train_steps_per_epoch,
train_batch_size=self.train_batch_size,
**config_args)
# Training
model = CARE(config, 'CH_{}_model'.format(ch), basedir=pathlib.Path(self.out_dir) / 'models')
# Show learning curve and example validation results
try:
history = model.train(X,Y, validation_data=(X_val,Y_val))
except tf.errors.ResourceExhaustedError:
print("ResourceExhaustedError: Aborting...\n Training data too big for GPU. Are other GPU jobs running? Perhaps, reduce batch-size or patch-size?")
return
#print(sorted(list(history.history.keys())))
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'],['mse','val_mse','mae','val_mae'])
#print(sorted(list(history.history.keys())))
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'],['mse','val_mse','mae','val_mae'])
plt.figure(figsize=(12,7))
_P = model.keras_model.predict(X_val[:5])
plt.figure(figsize=(12,7))
_P = model.keras_model.predict(X_val[:5])
plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
plt.suptitle('5 example validation patches\n'
'top row: input (source), '
'middle row: target (ground truth), '
'bottom row: predicted from source');
plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5, cmap="gray")
plt.suptitle('5 example validation patches\n'
'top row: input (source), '
'middle row: target (ground truth), '
'bottom row: predicted from source');
plt.show()
plt.show()
print("-- Export model for use in Fiji...")
model.export_TF()
print("Done")
print("-- Export model for use in Fiji...")
model.export_TF()
print("Done")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment