# Source code for smlmlp.modules.block_LP._functions.globdetection.globdet_local_maxima

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author        : Lancelot PINCET
# GitHub        : https://github.com/LancelotPincet

from smlmlp import block, detect_spatial_maxima
import numpy as np


@block()
def globdet_local_maxima(
    snrs,
    /,
    snr_thresh,
    channels_spatial_kernels,
    globdet_groups=None,
    *,
    f0=0,
    channels_pixels_nm=1.0,
    cuda=False,
    parallel=False,
):
    """Detect local maxima on global channels and map them to input channels.

    Detection runs once per global (grouped) channel. Each group uses a kernel
    derived from the spatial kernels of its input channels, and all groups are
    detected at one common reference pixel size. Every detection of a group is
    then replicated onto each input channel of that group, and all replicas
    share one point ID.

    Parameters
    ----------
    snrs : sequence
        One input per global detection channel; ``len(snrs)`` sets the number
        of global channels.
    snr_thresh
        Threshold forwarded to ``detect_spatial_maxima``.
    channels_spatial_kernels : sequence
        One spatial kernel per input channel; its length sets the number of
        input channels.
    globdet_groups : sequence of sequences of int, optional
        0-based input-channel indices merged into each global channel. When
        None, the mapping is one-to-one if the channel counts match, or a
        single group of all input channels if there is one global channel.
    f0 : optional
        Forwarded to ``detect_spatial_maxima``.
    channels_pixels_nm : float, pair, or per-channel sequence, optional
        Input-channel pixel size(s) in nanometers. The smallest value per axis
        is used as the common reference pixel.
    cuda, parallel : bool, optional
        Forwarded to ``detect_spatial_maxima``.

    Returns
    -------
    fr, x, y : ndarray
        Values returned by ``detect_spatial_maxima`` for each detection
        (presumably frame and position), one row per (detection, input
        channel).
    ch : ndarray
        1-based input-channel number of each row. It is uint8 unless there are
        more than 255 input channels.
    pnt : ndarray of uint64
        1-based point ID, shared by all rows that come from the same global
        detection.
    info : dict
        Info dict from ``detect_spatial_maxima``, extended with
        ``"globdet_groups"`` (0-based groups) and ``"ref_pixel_nm"``.

    Rows are sorted by channel, then ``fr``, then ``y``, then ``x``.

    Raises
    ------
    ValueError
        If the groups or pixel sizes are inconsistent with the channel counts.
    """
    n_globdet_channels = len(snrs)
    n_input_channels = len(channels_spatial_kernels)
    globdet_groups = _normalize_globdet_groups(
        globdet_groups,
        n_globdet_channels,
        n_input_channels,
    )
    group_kernels = [_group_kernel(channels_spatial_kernels, group) for group in globdet_groups]
    ref_pixel = _reference_pixel(channels_pixels_nm, n_input_channels)

    # Every global channel is detected at the same reference pixel size
    # (ref_pixel is an immutable tuple, so sharing it across entries is safe).
    fr, x, y, ch, info = detect_spatial_maxima(
        snrs,
        snr_thresh,
        group_kernels,
        f0=f0,
        channels_pixels_nm=[ref_pixel] * n_globdet_channels,
        cuda=cuda,
        parallel=parallel,
    )

    # uint8 is kept for the usual small channel counts. It is widened only when
    # channel numbers (up to n_input_channels) would not fit, instead of
    # overflowing.
    ch_dtype = np.min_scalar_type(n_input_channels)

    F, X, Y, C, P = [], [], [], [], []
    pnt_counter = 0
    for group_position, group in enumerate(globdet_groups):
        # Detection channels are matched 1-based (group_position + 1).
        group_mask = ch == group_position + 1
        fr_group = fr[group_mask]
        x_group = x[group_mask]
        y_group = y[group_mask]
        n_group = len(fr_group)
        # Point IDs are 1-based and unique per global detection. Every input
        # channel of the group reuses the same IDs.
        pnt_group = np.arange(
            pnt_counter + 1,
            pnt_counter + n_group + 1,
            dtype=np.uint64,
        )
        pnt_counter += n_group
        for channel_index in group:
            # No copies needed: np.hstack below always allocates new arrays.
            F.append(fr_group)
            X.append(x_group)
            Y.append(y_group)
            C.append(np.full(fr_group.shape, channel_index + 1, dtype=ch_dtype))
            P.append(pnt_group)

    if F:
        fr_out = np.hstack(F)
        x_out = np.hstack(X)
        y_out = np.hstack(Y)
        ch_out = np.hstack(C)
        pnt_out = np.hstack(P)
    else:
        # No groups at all: return typed empty arrays.
        fr_out = np.empty(0, dtype=np.int32)
        x_out = np.empty(0, dtype=np.float32)
        y_out = np.empty(0, dtype=np.float32)
        ch_out = np.empty(0, dtype=ch_dtype)
        pnt_out = np.empty(0, dtype=np.uint64)

    # lexsort treats its LAST key as primary: sort by channel, fr, y, then x.
    order = np.lexsort((x_out, y_out, fr_out, ch_out))
    fr_out, x_out, y_out, ch_out, pnt_out = (
        values[order] for values in (fr_out, x_out, y_out, ch_out, pnt_out)
    )

    info["globdet_groups"] = [list(group) for group in globdet_groups]
    info["ref_pixel_nm"] = ref_pixel
    return fr_out, x_out, y_out, ch_out, pnt_out, info
def _normalize_globdet_groups(globdet_groups, n_globdet_channels, n_input_channels):
    """Normalize group definitions into 0-based input-channel indices.

    Parameters
    ----------
    globdet_groups : sequence of sequences of int, or None
        For each global channel, the input-channel indices it merges.
    n_globdet_channels : int
        Number of global detection channels.
    n_input_channels : int
        Number of input channels.

    Returns
    -------
    list of list of int
        One list of 0-based input-channel indices per global channel. When
        ``globdet_groups`` is None, this is a one-to-one mapping if the counts
        match, or a single group of every input channel if there is exactly
        one global channel.

    Raises
    ------
    ValueError
        If groups are required but missing, have the wrong count, contain an
        empty group, or reference an index outside ``[0, n_input_channels)``.
    """
    if globdet_groups is None:
        if n_globdet_channels == n_input_channels:
            return [[index] for index in range(n_input_channels)]
        if n_globdet_channels == 1:
            return [list(range(n_input_channels))]
        raise ValueError("globdet_groups is required when outputs do not match input channels")
    if len(globdet_groups) != n_globdet_channels:
        raise ValueError("globdet_groups does not have the same length as global channels")
    normalized = []
    for group in globdet_groups:
        if len(group) == 0:
            raise ValueError("globdet_groups cannot contain empty groups")
        # NOTE(review): duplicate indices (within or across groups) are not
        # rejected; each occurrence is kept as-is. Confirm this is intended.
        indices = [int(index) for index in group]
        if min(indices) < 0 or max(indices) >= n_input_channels:
            raise ValueError("globdet_groups contains channel indices out of bounds")
        normalized.append(indices)
    return normalized


def _group_kernel(kernels, group):
    """Build one detection kernel for a merged group.

    The kernels of the group's channels are converted to float32. If they
    all share one shape, their element-wise mean is returned as float32.
    Otherwise the first channel's kernel is returned unchanged, which is a
    silent fallback: no averaging is attempted across mismatched shapes.
    """
    selected = [np.asarray(kernels[index], dtype=np.float32) for index in group]
    same_shape = all(kernel.shape == selected[0].shape for kernel in selected)
    if same_shape:
        return np.mean(np.stack(selected, axis=0), axis=0, dtype=np.float32)
    return selected[0]


def _reference_pixel(channels_pixels_nm, n_input_channels):
    """Return the common global pixel size in nanometers.

    ``channels_pixels_nm`` may be:

    * a scalar, used for both axes of every channel;
    * a flat sequence of exactly two values, used as the per-axis pair of
      every channel (this reading wins when there are also two channels);
    * a flat sequence with one scalar per input channel;
    * a sequence with one per-axis pair per input channel.

    Returns
    -------
    tuple of float
        The smallest pixel size along each of the two axes.

    Raises
    ------
    ValueError
        If the sequence is empty, its length does not match the number of
        input channels, or a per-channel entry holds fewer than two values.
    """
    try:
        n_pixels = len(channels_pixels_nm)
    except TypeError:
        # Scalar: the same square pixel for every channel. The minimum over
        # identical replicas is the value itself, so this also holds when
        # there are no input channels.
        pixel = float(channels_pixels_nm)
        return (pixel, pixel)
    if n_pixels == 0:
        raise ValueError("channels_pixels_nm cannot be empty")
    try:
        len(channels_pixels_nm[0])
    except TypeError:
        # Flat sequence of scalars.
        if n_pixels == 2:
            # One (axis-0, axis-1) pair shared by every channel.
            first, second = (float(v) for v in channels_pixels_nm)
            return (first, second)
        if n_pixels != n_input_channels:
            raise ValueError("channels_pixels_nm does not have the same length as channels")
        pixels = [(float(v), float(v)) for v in channels_pixels_nm]
    else:
        # One per-axis pair per input channel.
        if n_pixels != n_input_channels:
            raise ValueError("channels_pixels_nm does not have the same length as channels")
        pixels = [tuple(float(v) for v in pix) for pix in channels_pixels_nm]
        if any(len(pix) < 2 for pix in pixels):
            raise ValueError("channels_pixels_nm entries must contain two values")
    return (min(pix[0] for pix in pixels), min(pix[1] for pix in pixels))