# Source code for smlmlp.modules.block_LP._functions.detection.detect_spatial_maxima

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author        : Lancelot PINCET
# GitHub        : https://github.com/LancelotPincet



from smlmlp import block, Config
from arrlp import get_xp, nb_threads
import numba as nb
from numba import cuda as nb_cuda
import numpy as np
import math



@block()
def detect_spatial_maxima(
    snrs,
    /,
    snr_thresh,
    channels_spatial_kernels,
    *,
    f0=0,
    channels_pixels_nm=1.0,
    cuda=False,
    parallel=False,
):
    """
    Detect local spatial maxima in thresholded SNR images.

    This function detects local maxima in SNR stacks using a
    channel-specific spatial footprint derived from the provided kernels.
    Detected maxima are refined to subpixel coordinates through a local
    center-of-mass estimate, then converted to physical units using the
    channel pixel sizes.

    Parameters
    ----------
    snrs : sequence of ndarray
        Sequence of SNR stacks, one per channel. Each stack is expected to
        have shape ``(n_frames, height, width)``.
    snr_thresh : float
        Detection threshold applied to the SNR values.
    channels_spatial_kernels : sequence of ndarray
        Spatial kernels used to define the local-maxima footprint for each
        channel.
    f0 : int, optional
        Frame offset added to the detected frame indices before converting
        them back to the 1-based convention.
    channels_pixels_nm : float or sequence, optional
        Pixel size in nanometers for each channel. This value is normalized
        through :class:`smlmlp.Config` so that one pixel size pair is
        available for each channel.
    cuda : bool, optional
        Whether to use GPU acceleration.
    parallel : bool, optional
        Whether to enable CPU parallelization.

    Returns
    -------
    tuple
        A tuple ``(fr, x, y, ch, info)`` where:

        - ``fr`` is the concatenated array of detected frame indices,
        - ``x`` is the concatenated array of detected x coordinates in
          nanometers,
        - ``y`` is the concatenated array of detected y coordinates in
          nanometers,
        - ``ch`` is the concatenated array of detected channel indices,
        - ``info`` is a dictionary containing reusable intermediate results.

        The dictionary contains the following keys:

        ``'footprints'``
            Boolean detection footprints derived from the spatial kernels.
        ``'channels_pixels_nm'``
            Normalized per-channel pixel sizes used for coordinate
            conversion.

    Notes
    -----
    The footprint is obtained by thresholding each spatial kernel using the
    Sparrow-limit-based factor:

    .. math::

        \\exp\\left(-\\frac{(0.47 / 0.21)^2}{2}\\right)

    On CPU, local maxima are first detected on the integer grid and then
    refined with a 5x5 center-of-mass estimate. On GPU, both operations are
    combined in a single kernel.

    On GPU, the output buffers are sized by the number of pixels that pass
    the threshold gate. Every detection is such a pixel, so this is a true
    upper bound even when plateaus produce several tied maxima, and it does
    not depend on the footprint being non-empty.

    Channels are paired with kernels and pixel sizes through ``zip``, so
    iteration stops at the shortest of the three sequences.

    Examples
    --------
    >>> import numpy as np
    >>> snr = np.random.rand(5, 16, 16).astype(np.float32)
    >>> kernel = np.ones((3, 3), dtype=np.float32)
    >>> fr, x, y, ch, info = detect_spatial_maxima(
    ...     [snr],
    ...     0.9,
    ...     [kernel],
    ... )
    >>> fr.ndim == x.ndim == y.ndim == ch.ndim == 1
    True
    >>> "footprints" in info
    True

    >>> snrs = [
    ...     np.random.rand(4, 16, 16).astype(np.float32),
    ...     np.random.rand(4, 16, 16).astype(np.float32),
    ... ]
    >>> kernels = [
    ...     np.ones((3, 3), dtype=np.float32),
    ...     np.ones((5, 5), dtype=np.float32),
    ... ]
    >>> fr, x, y, ch, info = detect_spatial_maxima(
    ...     snrs,
    ...     1.2,
    ...     kernels,
    ...     channels_pixels_nm=[(100.0, 100.0), (110.0, 110.0)],
    ... )
    >>> len(info["footprints"])
    2
    """
    # Select the array backend matching the requested execution mode.
    xp = get_xp(cuda)

    # Build boolean detection footprints from the spatial kernels.
    k_fp = math.exp(-(0.47 / 0.21) ** 2 / 2)
    footprints = [
        kernel > kernel.max() * k_fp for kernel in channels_spatial_kernels
    ]

    # Normalize the pixel size input so that one pixel size pair is available
    # for each channel.
    channels_pixels_nm = Config(
        ncameras=len(snrs),
        channels_cameras_nm=channels_pixels_nm,
    ).channels_cameras_nm

    F, X, Y, C = [], [], [], []
    for pos, (snr, footprint, pixel) in enumerate(
        zip(snrs, footprints, channels_pixels_nm)
    ):
        snr = xp.asarray(snr)
        footprint = xp.asarray(footprint)

        # GPU implementation: detect maxima and compute the local center of
        # mass directly in the CUDA kernel.
        if cuda:
            # The kernel discards every pixel with ``val < snr_thresh``, so
            # the number of remaining pixels bounds the number of detections.
            # The previous ``snr.size / sum(footprint)`` heuristic could be
            # exceeded (tied maxima, densely packed maxima) and divided by
            # zero for an empty footprint.
            max_points = int(snr.size) - int(xp.sum(snr < snr_thresh))

            if max_points > 0:
                fr_out = xp.empty(max_points, dtype=xp.int32)
                y_out = xp.empty(max_points, dtype=xp.float32)
                x_out = xp.empty(max_points, dtype=xp.float32)
                counter = xp.zeros(1, dtype=xp.int32)

                shape = snr.shape
                threads_per_block = (2, 16, 16)
                # Ceil-divide so the grid covers the whole stack.
                blocks_per_grid = tuple(
                    (size + threads - 1) // threads
                    for size, threads in zip(shape, threads_per_block)
                )

                det_gpu[blocks_per_grid, threads_per_block](
                    snr,
                    snr_thresh,
                    footprint,
                    fr_out,
                    y_out,
                    x_out,
                    counter,
                )

                # Never read past the allocated buffers, whatever the
                # counter reports.
                n = min(int(xp.asnumpy(counter)[0]), max_points)
                fr = xp.asnumpy(fr_out[:n])
                y = xp.asnumpy(y_out[:n])
                x = xp.asnumpy(x_out[:n])
            else:
                # No pixel passes the threshold: nothing to launch.
                fr = np.empty(0, dtype=np.int32)
                y = np.empty(0, dtype=np.float32)
                x = np.empty(0, dtype=np.float32)

        # CPU implementation: detect integer-grid maxima first, then refine
        # them with a local center of mass.
        else:
            mask = xp.zeros_like(snr, dtype=xp.bool_)
            with nb_threads(parallel):
                maxi_cpu(mask, snr, snr_thresh, footprint)

            fr, y_int, x_int = xp.nonzero(mask)
            y = xp.empty_like(y_int, dtype=xp.float32)
            x = xp.empty_like(x_int, dtype=xp.float32)

            if len(fr):
                with nb_threads(parallel):
                    com_cpu(snr, fr, y_int, x_int, y, x)

        # Convert detections back to the external frame convention and to
        # physical coordinates, then sort them lexicographically.
        fr += f0 + 1
        y *= pixel[0]
        x *= pixel[1]
        # The channel label is constant within a channel, so it does not
        # need to be reordered with the other columns.
        c = np.full_like(fr, fill_value=pos + 1, dtype=np.uint8)

        # ``np.lexsort`` uses the last key as primary: frame, then y, then x.
        argsort = np.lexsort((x, y, fr))
        F.append(fr[argsort])
        Y.append(y[argsort])
        X.append(x[argsort])
        C.append(c)

    info = {
        "footprints": footprints,
        "channels_pixels_nm": channels_pixels_nm,
    }
    return np.hstack(F), np.hstack(X), np.hstack(Y), np.hstack(C), info
@nb.njit(parallel=True, nogil=True, cache=True, fastmath=True)
def maxi_cpu(mask, snr, snr_thresh, footprint):
    """
    Mark thresholded local maxima on CPU.

    A pixel is marked when its value is not below ``snr_thresh`` and no
    in-bounds neighbor selected by ``footprint`` is strictly greater. Tied
    neighbors are therefore all marked.

    Parameters
    ----------
    mask : ndarray of bool
        Output array indexed like ``snr``; entries of detected maxima are
        set to ``True`` in place. Other entries are left untouched.
    snr : ndarray
        Stack indexed as ``(frame, y, x)``.
    snr_thresh : float
        Minimum value for a pixel to be considered.
    footprint : ndarray of bool
        2D neighborhood mask, centered at ``(height // 2, width // 2)``.
    """
    n_frames, height, width = snr.shape
    fp_height, fp_width = footprint.shape
    cy, cx = fp_height // 2, fp_width // 2

    # Frames are independent, so they are distributed over threads.
    for fr in nb.prange(n_frames):
        frame = snr[fr]
        frame_mask = mask[fr]
        for y in range(height):
            for x in range(width):
                val = frame[y, x]
                if val < snr_thresh:
                    continue

                is_max = True
                for dy in range(fp_height):
                    for dx in range(fp_width):
                        if not footprint[dy, dx]:
                            continue
                        ny = y + dy - cy
                        nx = x + dx - cx
                        # Neighbors outside the image are ignored.
                        if ny < 0 or ny >= height or nx < 0 or nx >= width:
                            continue
                        if frame[ny, nx] > val:
                            is_max = False
                            break
                    if not is_max:
                        break

                if is_max:
                    frame_mask[y, x] = True


@nb.njit(parallel=True, nogil=True, cache=True, fastmath=True)
def com_cpu(snr, fr_idx, y_idx, x_idx, y_out, x_out):
    """
    Refine detected maxima with a 5x5 center of mass on CPU.

    Parameters
    ----------
    snr : ndarray
        Stack indexed as ``(frame, y, x)``; its values are the weights.
    fr_idx, y_idx, x_idx : ndarray of int
        Integer coordinates of the detections, one entry per detection.
    y_out, x_out : ndarray
        Output arrays receiving the refined coordinates, in pixels.

    Notes
    -----
    The window is clipped at the image borders. When the sum of weights is
    not positive, the integer position is returned unchanged.
    """
    n = len(fr_idx)
    _, height, width = snr.shape

    for i in nb.prange(n):
        fr = fr_idx[i]
        y = y_idx[i]
        x = x_idx[i]

        xnum = 0.0
        ynum = 0.0
        denom = 0.0
        for dy in range(-2, 3):
            for dx in range(-2, 3):
                ny = y + dy
                nx = x + dx
                if ny < 0 or ny >= height or nx < 0 or nx >= width:
                    continue
                val = snr[fr, ny, nx]
                xnum += nx * val
                ynum += ny * val
                denom += val

        if denom > 0:
            x_out[i] = xnum / denom
            y_out[i] = ynum / denom
        else:
            x_out[i] = x
            y_out[i] = y


@nb_cuda.jit(fastmath=True, cache=True)
def det_gpu(snr, snr_thresh, footprint, fr_out, y_out, x_out, counter):
    """
    Detect and refine local maxima on GPU.

    Each thread handles one ``(frame, y, x)`` pixel. It applies the same
    threshold and footprint test as the CPU path, computes the 5x5 center of
    mass, and appends the result to the output buffers.

    Parameters
    ----------
    snr : device array
        Stack indexed as ``(frame, y, x)``.
    snr_thresh : float
        Minimum value for a pixel to be considered.
    footprint : device array of bool
        2D neighborhood mask, centered at ``(height // 2, width // 2)``.
    fr_out, y_out, x_out : device array
        1D output buffers for frame index and refined coordinates.
    counter : device array
        One-element integer array used as the append cursor. After the
        launch it holds the total number of detections, which may exceed
        the buffer length if the buffers were too small.
    """
    fr, y, x = nb_cuda.grid(3)
    n_frames, height, width = snr.shape
    fp_height, fp_width = footprint.shape

    # Threads of the padded grid that fall outside the stack do nothing.
    if fr >= n_frames or y >= height or x >= width:
        return

    val = snr[fr, y, x]
    if val < snr_thresh:
        return

    cy = fp_height // 2
    cx = fp_width // 2

    # Local-maximum test inside the footprint.
    is_max = True
    for dy in range(fp_height):
        for dx in range(fp_width):
            if not footprint[dy, dx]:
                continue
            ny = y + dy - cy
            nx = x + dx - cx
            if ny < 0 or ny >= height or nx < 0 or nx >= width:
                continue
            if snr[fr, ny, nx] > val:
                is_max = False
                break
        if not is_max:
            break

    if not is_max:
        return

    # 5x5 center-of-mass refinement.
    xnum = 0.0
    ynum = 0.0
    denom = 0.0
    for dy in range(-2, 3):
        for dx in range(-2, 3):
            ny = y + dy
            nx = x + dx
            if ny < 0 or ny >= height or nx < 0 or nx >= width:
                continue
            v = snr[fr, ny, nx]
            xnum += nx * v
            ynum += ny * v
            denom += v

    if denom > 0:
        xf = xnum / denom
        yf = ynum / denom
    else:
        xf = x
        yf = y

    # Atomically reserve a slot. The counter is always incremented so the
    # host can see the true detection count, but the write is skipped when
    # the slot lies past the end of the buffers, which would otherwise
    # corrupt device memory.
    idx = nb_cuda.atomic.add(counter, 0, 1)
    if idx < fr_out.shape[0]:
        fr_out[idx] = fr
        y_out[idx] = yf
        x_out[idx] = xf