#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author : Lancelot PINCET
# GitHub : https://github.com/LancelotPincet
from smlmlp import block, detect_spatial_maxima
import numpy as np
[docs]
@block()
def globdet_local_maxima(
    snrs,
    /,
    snr_thresh,
    channels_spatial_kernels,
    globdet_groups=None,
    *,
    f0=0,
    channels_pixels_nm=1.0,
    cuda=False,
    parallel=False,
):
    """Detect local maxima on global channels and map them to input channels.

    Every global detection channel stands for a group of input channels.
    Detections labelled with the 1-based position of a group are replicated
    once per input channel of that group; all replicas of one detection share
    the same point identifier.

    Parameters
    ----------
    snrs : sequence
        One entry per global detection channel, forwarded unchanged to
        ``detect_spatial_maxima``.
    snr_thresh
        Threshold forwarded unchanged to ``detect_spatial_maxima``.
    channels_spatial_kernels : sequence
        One spatial kernel per input channel.
    globdet_groups : sequence of sequences of int, optional
        For each global channel, the 0-based indices of the input channels it
        covers. ``None`` lets the grouping be inferred when the global
        channels either match the input channels one-to-one or collapse them
        into a single channel.
    f0 : optional
        Forwarded unchanged to ``detect_spatial_maxima``.
    channels_pixels_nm : scalar or sequence, optional
        Pixel size(s) of the input channels. They are reduced to one
        reference pixel that is passed for every global channel.
    cuda, parallel : bool, optional
        Forwarded unchanged to ``detect_spatial_maxima``.

    Returns
    -------
    fr_out, x_out, y_out : numpy.ndarray
        Frame and coordinates of every replicated detection.
    ch_out : numpy.ndarray of uint8
        1-based input-channel number of every replicated detection.
    pnt_out : numpy.ndarray of uint64
        1-based identifier of the underlying global detection.
    info
        The object returned by ``detect_spatial_maxima``, extended with the
        keys ``"globdet_groups"`` and ``"ref_pixel_nm"``.

    Notes
    -----
    Outputs are ordered by channel, then frame, then y, then x.
    """
    n_global = len(snrs)
    n_inputs = len(channels_spatial_kernels)
    groups = _normalize_globdet_groups(globdet_groups, n_global, n_inputs)

    ref_pixel = _reference_pixel(channels_pixels_nm, n_inputs)
    merged_kernels = [
        _group_kernel(channels_spatial_kernels, members) for members in groups
    ]

    fr, x, y, ch, info = detect_spatial_maxima(
        snrs,
        snr_thresh,
        merged_kernels,
        f0=f0,
        channels_pixels_nm=[ref_pixel] * n_global,
        cuda=cuda,
        parallel=parallel,
    )

    frames, xs, ys, channels, points = [], [], [], [], []
    n_assigned = 0
    # Detected channel labels are 1-based, hence the enumeration offset.
    for label, members in enumerate(groups, start=1):
        selected = ch == label
        fr_sel, x_sel, y_sel = fr[selected], x[selected], y[selected]
        n_found = len(fr_sel)
        # Identifiers keep counting across groups so they stay unique.
        ids = np.arange(n_assigned + 1, n_assigned + n_found + 1, dtype=np.uint64)
        n_assigned += n_found
        for member in members:
            # No explicit copies needed: np.hstack below allocates new arrays.
            frames.append(fr_sel)
            xs.append(x_sel)
            ys.append(y_sel)
            channels.append(np.full(fr_sel.shape, member + 1, dtype=np.uint8))
            points.append(ids)

    if frames:
        fr_out = np.hstack(frames)
        x_out = np.hstack(xs)
        y_out = np.hstack(ys)
        ch_out = np.hstack(channels)
        pnt_out = np.hstack(points)
        # lexsort treats the LAST key as the primary one.
        order = np.lexsort((x_out, y_out, fr_out, ch_out))
        fr_out, x_out, y_out = fr_out[order], x_out[order], y_out[order]
        ch_out, pnt_out = ch_out[order], pnt_out[order]
    else:
        # No group at all: return typed empty arrays.
        fr_out = np.empty(0, dtype=np.int32)
        x_out = np.empty(0, dtype=np.float32)
        y_out = np.empty(0, dtype=np.float32)
        ch_out = np.empty(0, dtype=np.uint8)
        pnt_out = np.empty(0, dtype=np.uint64)

    info["globdet_groups"] = [list(members) for members in groups]
    info["ref_pixel_nm"] = ref_pixel
    return fr_out, x_out, y_out, ch_out, pnt_out, info
def _normalize_globdet_groups(globdet_groups, n_globdet_channels, n_input_channels):
    """Validate the channel grouping and return it as lists of 0-based ints.

    Parameters
    ----------
    globdet_groups : sequence of sequences, or None
        For each global channel, the input-channel indices it covers.
    n_globdet_channels : int
        Number of global detection channels.
    n_input_channels : int
        Number of input channels.

    Returns
    -------
    list of list of int
        One list of input-channel indices per global channel.

    Raises
    ------
    ValueError
        If the grouping is missing and cannot be inferred, has the wrong
        length, contains an empty group, or references a channel outside
        ``[0, n_input_channels)``.
    """
    if globdet_groups is None:
        # Only two layouts can be inferred without an explicit grouping.
        if n_globdet_channels == n_input_channels:
            return [[channel] for channel in range(n_input_channels)]
        if n_globdet_channels == 1:
            return [list(range(n_input_channels))]
        raise ValueError("globdet_groups is required when outputs do not match input channels")

    if len(globdet_groups) != n_globdet_channels:
        raise ValueError("globdet_groups does not have the same length as global channels")

    result = []
    for members in globdet_groups:
        if len(members) == 0:
            raise ValueError("globdet_groups cannot contain empty groups")
        as_ints = [int(member) for member in members]
        if any(not 0 <= member < n_input_channels for member in as_ints):
            raise ValueError("globdet_groups contains channel indices out of bounds")
        result.append(as_ints)
    return result
def _group_kernel(kernels, group):
    """Build one detection kernel for a merged group.

    Parameters
    ----------
    kernels : sequence of array-like
        Spatial kernels, indexable by the entries of ``group``.
    group : iterable of int
        Indices of the kernels to merge.

    Returns
    -------
    numpy.ndarray of float32
        The element-wise mean of the selected kernels when they all share one
        shape; otherwise the first selected kernel, because kernels of
        different shapes cannot be stacked.
    """
    members = [np.asarray(kernels[position], dtype=np.float32) for position in group]
    distinct_shapes = {kernel.shape for kernel in members}
    if len(distinct_shapes) > 1:
        # Mixed shapes: fall back to the first kernel of the group.
        return members[0]
    stacked = np.stack(members, axis=0)
    return np.mean(stacked, axis=0, dtype=np.float32)
def _reference_pixel(channels_pixels_nm, n_input_channels):
    """Return the common global pixel size in nanometers.

    The per-channel pixel sizes are expanded to one pair per input channel
    and the smallest value along each of the two axes is kept.

    Parameters
    ----------
    channels_pixels_nm : scalar or sequence
        Accepted layouts:

        * a scalar, used for both axes of every channel;
        * a flat sequence of two scalars, used as one pair for every channel
          (this reading takes precedence even when there are exactly two
          input channels);
        * a flat sequence with one scalar per channel, each used for both
          axes of its channel;
        * a sequence with one pair per channel.
    n_input_channels : int
        Number of input channels.

    Returns
    -------
    tuple of float
        ``(min of first components, min of second components)``.

    Raises
    ------
    ValueError
        If ``channels_pixels_nm`` is an empty sequence, if its length does
        not fit ``n_input_channels``, or if there is no channel to reduce
        over (``min`` of an empty sequence).
    """
    mismatch = "channels_pixels_nm does not have the same length as channels"
    try:
        n_pixels = len(channels_pixels_nm)
    except TypeError:
        # Scalar: same size on both axes for every channel.
        size = float(channels_pixels_nm)
        pixels = [(size, size)] * n_input_channels
    else:
        if n_pixels == 0:
            # Previously leaked an IndexError from the [0] probe below.
            raise ValueError("channels_pixels_nm cannot be empty")
        try:
            len(channels_pixels_nm[0])
        except TypeError:
            # Flat sequence of scalars.
            if n_pixels == 2:
                pair = tuple(float(value) for value in channels_pixels_nm)
                pixels = [pair] * n_input_channels
            elif n_pixels == n_input_channels:
                pixels = [(float(value), float(value)) for value in channels_pixels_nm]
            else:
                raise ValueError(mismatch)
        else:
            # Sequence of per-channel pairs.
            if n_pixels != n_input_channels:
                raise ValueError(mismatch)
            pixels = [tuple(float(value) for value in pair) for pair in channels_pixels_nm]
    return (min(pair[0] for pair in pixels), min(pair[1] for pair in pixels))