Commit 47640a15 authored by Christoph Sommer's avatar Christoph Sommer
Browse files

added prob map thresholding

parent 850d3782
......@@ -5,7 +5,7 @@ import numpy
import logging
import argparse
import networkx as nx
from skimage import morphology, measure
from skimage import morphology, measure, filters
def build_branch_graph(skel_branches):
......@@ -48,9 +48,22 @@ def extract_pos_3d(skan_branches, scale=2):
return pos_dict
def read_ilastik_seg(ilastik_seg_fn):
with h5py.File(ilastik_seg_fn, "r") as hf:
return hf["exported_data"][()][0, ..., 0]
def segment_prob_map(ilastik_fn, sigma, thresh):
with h5py.File(ilastik_fn, "r") as hf:
prob_map = hf["exported_data"][()][0, ..., 0] # time point 0
if prob_map.dtype != numpy.uint8:
logging.warn(
f" Probability map seems to have wrong pixel type. Expected uint8. got {prob_map.dtype}"
)
assert (
len(prob_map.shape) == 3
), f"Wrong dimensions. Expected 3D, got shape: '{prob_map.shape}'"
img = filters.gaussian(prob_map, sigma=sigma, preserve_range=True)
seg = measure.label(img > 255 * thresh)
return seg
def remove_small_segments(seg, min_size):
......@@ -58,14 +71,12 @@ def remove_small_segments(seg, min_size):
rp = measure.regionprops(seg)
rp = sorted(rp, key=lambda r: r.area, reverse=True)
logging.info(
f" - Found {len(rp)} segments size=: {','.join([str(r.area) for r in rp])}"
f" - Found {len(rp)} segments sizes (px): {','.join([str(r.area) for r in rp[:10]])} ..."
)
for r in rp:
if r.area < min_size:
seg[r.coords[:, 0], r.coords[:, 1], r.coords[:, 2]] = 0
seg = morphology.remove_small_objects(seg, min_size=min_size)
return measure.label(seg).astype(numpy.uint8)
return measure.label(seg).astype(numpy.uint16)
def skeletonize(seg_binary, vx_size=(1, 1, 1)):
......@@ -151,15 +162,17 @@ def write_swc(fn, swc_table):
fh.writelines(lines)
def run(ilastik_seg_fn, min_size, scale):
logging.info(f"File: {ilastik_seg_fn} ")
logging.info(f"MinSize: {min_size} ")
logging.info(f"Scale: {scale} (reso. level)")
def run(ilastik_seg_fn, min_size, rl, sigma, thresh):
logging.info(f"File: {ilastik_seg_fn}")
logging.info(f"MinSize: {min_size}")
logging.info(f"ResoLev: {rl}")
logging.info(f"Sigma: {sigma}")
logging.info(f"Thresh: {thresh}")
logging.info("-" * 80)
base_fn = os.path.splitext(ilastik_seg_fn)[0]
logging.info(" - Read segmentation")
img_seg = read_ilastik_seg(ilastik_seg_fn)
logging.info(f" - Read probability maps and segment")
img_seg = segment_prob_map(ilastik_seg_fn, sigma, thresh)
img_seg = remove_small_segments(img_seg, min_size)
logging.info(f" - Removed segments smaller {min_size} px")
......@@ -168,7 +181,7 @@ def run(ilastik_seg_fn, min_size, scale):
for seg_id in range(1, 1 + img_seg.max()):
logging.info(f" {seg_id}: Skeletonize")
skel, skel_branches = skeletonize(img_seg == seg_id)
pos_3d = extract_pos_3d(skel_branches, scale)
pos_3d = extract_pos_3d(skel_branches, rl)
logging.info(f" : Build branch graph")
graph_branches = build_branch_graph(skel_branches)
......@@ -191,13 +204,38 @@ def run(ilastik_seg_fn, min_size, scale):
def get_args():
description = """Extract skeletons from ilastik dendtite segmentation and export to .swc for import in Imaris"""
description = """Extract skeletons from ilastik dendrite probability maps and export to .swc for import in Imaris"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument("ilastik_seg_h5", nargs="+", type=str)
parser.add_argument("-ms", "--min_size", type=int, default=10000)
parser.add_argument("-s", "--scale", type=int, default=2)
parser.add_argument(
"ilastik_h5",
nargs="+",
type=str,
help="ilastik probability map (single channel) in 8-bit",
)
parser.add_argument(
"-ms",
"--min_size",
type=int,
default=10000,
help="Minimum object size in pixel",
)
parser.add_argument(
"-rl", "--resolution_level", type=int, default=2, help="Resolution level used",
)
parser.add_argument(
"-s",
"--smooth_sigma",
type=int,
nargs=3,
default=(0.5, 0.5, 0.5),
help="Smooth prob. map before thresholding. Gaussian sigma in px for ZYX",
)
parser.add_argument(
"-t", "--threshold", type=float, default=0.5, help="Probability map threshold",
)
args = parser.parse_args()
return args
......@@ -208,6 +246,12 @@ if __name__ == "__main__":
args = get_args()
for ilastik_fn in args.ilastik_seg_h5:
run(ilastik_fn, args.min_size, args.scale)
for ilastik_fn in args.ilastik_h5:
run(
ilastik_fn,
args.min_size,
args.resolution_level,
args.smooth_sigma,
args.threshold,
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment