Source code for smlmlp.modules.block_LP._functions.globdetection.globdet_channels

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author        : Lancelot PINCET
# GitHub        : https://github.com/LancelotPincet

from smlmlp import block
from arrlp import get_xp, img_transform, transform_matrix
import numpy as np



[docs] @block() def globdet_channels( channels, /, globdet_mode="mean", channels_x_shifts_nm=None, channels_y_shifts_nm=None, channels_rotations_deg=None, channels_x_shears=None, channels_y_shears=None, globdet_groups=None, global_channels=None, *, channels_pixels_nm=1.0, channels_gains=0.25, cuda=False, parallel=False, ): """ Create a global channel for detection. This function geometrically transforms each input channel stack into the global detection frame, then merges all transformed channels into a single channel stack using either a mean or standard deviation projection across channels. Parameters ---------- channels : sequence of ndarray Input channel stacks. globdet_mode : {"mean", "std"}, optional Aggregation used to merge transformed channels. channels_x_shifts_nm, channels_y_shifts_nm : sequence of float or None, optional Per-channel translations in nanometers. If ``None``, zeros are used. channels_rotations_deg : sequence of float or None, optional Per-channel rotations, in degrees. If ``None``, zeros are used. channels_x_shears, channels_y_shears : sequence of float or None, optional Per-channel shear values. If ``None``, zeros are used. globdet_groups : sequence of sequence of int or None, optional Group definition for merged output channels. Each group contains input channel indices. A group index can be 0-based or 1-based. If ``None``, all channels are merged into one output channel. global_channels : sequence of ndarray or None, optional Optional preallocated output list for merged output channels. If provided with larger arrays, centered spatial views are reused when possible. channels_pixels_nm : float or sequence, optional Pixel size in nanometers. Can be scalar, ``(y, x)``, or per-channel. channels_gains : float or sequence, optional Per-channel gains in e-/ADU. Input channels are converted to photons before merging and the merged global channel is converted back to ADU using the average gain. cuda : bool, optional Whether to use CUDA execution. parallel : bool, optional Whether to use parallel execution. Returns ------- tuple A tuple ``(new_channels, info)`` where: - ``new_channels`` is a list containing one merged channel per group, - ``info`` is a dictionary containing reusable intermediate results. The dictionary contains the following keys: ``'channels_pixels_nm'`` Normalized per-channel pixel sizes. ``'ref_pix'`` Reference pixel size used for the common scaling. ``'scales_x'`` Per-channel x scaling factors. ``'scales_y'`` Per-channel y scaling factors. ``'channels_x_shifts_nm'`` Normalized per-channel x translations. ``'channels_y_shifts_nm'`` Normalized per-channel y translations. ``'channels_rotations_deg'`` Normalized per-channel rotations. ``'channels_x_shears'`` Normalized per-channel x shears. ``'channels_y_shears'`` Normalized per-channel y shears. ``'transform_matrices'`` Transform matrices applied to each input channel. ``'crop_shape'`` Common centered crop shape applied after transformation. ``'crop_bboxes'`` Per-channel centered crop boxes as ``(x0, y0, x1, y1)``. Raises ------ ValueError If no channel is provided or if a per-channel parameter length does not match ``len(channels)``. SyntaxError If ``mode`` is not recognized. Examples -------- >>> import numpy as np >>> channels = [ ... np.ones((2, 4, 4), dtype=np.float32), ... np.ones((2, 4, 4), dtype=np.float32) * 3, ... ] >>> global_channels, info = globdet_channels(channels) >>> len(global_channels) 1 >>> global_channels[0].shape (2, 4, 4) >>> global_channels, info = globdet_channels(channels, mode="std") >>> np.allclose(global_channels[0], 1.0) True """ if len(channels) == 0: raise ValueError("channels must contain at least one channel") # Select the array backend matching the requested execution mode. xp = get_xp(cuda) # Normalize the per-channel pixel sizes so shifts can be converted from # nanometers to pixels before building image-space transforms. channels_pixels_nm = _normalize_channels_pixels_nm( channels_pixels_nm, len(channels), ) ref_pix = ( min([pix[0] for pix in channels_pixels_nm]), min([pix[1] for pix in channels_pixels_nm]), ) scales_x = [ref_pix[1] / pix[1] for pix in channels_pixels_nm] scales_y = [ref_pix[0] / pix[0] for pix in channels_pixels_nm] # Initialize and normalize registration parameters to one value per channel. channels_x_shifts_nm = _normalize_channels_parameter( np.zeros(len(channels), dtype=np.float32) if channels_x_shifts_nm is None else channels_x_shifts_nm, len(channels), "channels_x_shifts_nm", ) channels_y_shifts_nm = _normalize_channels_parameter( np.zeros(len(channels), dtype=np.float32) if channels_y_shifts_nm is None else channels_y_shifts_nm, len(channels), "channels_y_shifts_nm", ) channels_rotations_deg = _normalize_channels_parameter( np.zeros(len(channels), dtype=np.float32) if channels_rotations_deg is None else channels_rotations_deg, len(channels), "channels_rotations_deg", ) channels_x_shears = _normalize_channels_parameter( np.zeros(len(channels), dtype=np.float32) if channels_x_shears is None else channels_x_shears, len(channels), "channels_x_shears", ) channels_y_shears = _normalize_channels_parameter( np.zeros(len(channels), dtype=np.float32) if channels_y_shears is None else channels_y_shears, len(channels), "channels_y_shears", ) channels_gains = _normalize_channels_parameter( channels_gains, len(channels), "channels_gains", ) channels_gains = [xp.float32(gain) for gain in channels_gains] globdet_groups = _normalize_globdet_groups(globdet_groups, len(channels)) globdet_channels_gains = [ xp.float32(np.mean([channels_gains[index] for index in group])) for group in globdet_groups ] if any(gain <= 0 for gain in globdet_channels_gains): raise ValueError("Average channels_gains must be strictly positive") match globdet_mode: case "mean": agg_func = xp.mean case "std": agg_func = xp.std case _: raise SyntaxError(f"Aggregation globdet_mode {globdet_mode} is not recognized") transformed = [] valid_masks = [] matrices = [] # Transform each channel stack into the global detection frame. for i in range(len(channels)): channel = xp.asarray(channels[i], dtype=xp.float32) * channels_gains[i] matrix1 = transform_matrix( channel, scalex=scales_x[i], scaley=scales_y[i], stacks=True, ) matrix2 = transform_matrix( channel, shiftx=channels_x_shifts_nm[i] / ref_pix[1], shifty=channels_y_shifts_nm[i] / ref_pix[0], angle=channels_rotations_deg[i], shearx=channels_x_shears[i], sheary=channels_y_shears[i], stacks=True, ) matrix = matrix1 @ matrix2 transformed.append( img_transform( channel, matrix=matrix, cuda=cuda, parallel=parallel, stacks=True, ) ) valid_masks.append( _transform_valid_mask( channel.shape[1:3], matrix, cuda=cuda, ) ) matrices.append(matrix) # Merge each requested group into one output global-detection channel. crop_shape = _largest_centered_valid_crop(valid_masks) transformed, crop_shape, crop_bboxes = _center_crop_arrays( transformed, crop_shape, stacks=True, ) transformed = xp.stack(transformed, axis=0) new_channels = [] for group_index, (group, group_gain) in enumerate( zip(globdet_groups, globdet_channels_gains) ): merged = agg_func(transformed[group], axis=0, dtype=xp.float32) merged = merged / group_gain new_channels.append( _copy_to_output( merged, global_channels, group_index, stacks=True, name="global_channels", ) ) info = { "channels_pixels_nm": channels_pixels_nm, "ref_pix": ref_pix, "scales_x": scales_x, "scales_y": scales_y, "channels_x_shifts_nm": channels_x_shifts_nm, "channels_y_shifts_nm": channels_y_shifts_nm, "channels_rotations_deg": channels_rotations_deg, "channels_x_shears": channels_x_shears, "channels_y_shears": channels_y_shears, "globdet_groups": [list(group) for group in globdet_groups], "channels_gains": channels_gains, "globdet_channels_gains": globdet_channels_gains, "transform_matrices": matrices, "crop_shape": crop_shape, "crop_bboxes": crop_bboxes, } return new_channels, info
def _normalize_channels_pixels_nm(channels_pixels_nm, n_channels): """Normalize pixel sizes to one ``(py, px)`` tuple per channel.""" try: n_pixels = len(channels_pixels_nm) except TypeError: channels_pixels_nm = [ (channels_pixels_nm, channels_pixels_nm) for _ in range(n_channels) ] else: try: len(channels_pixels_nm[0]) except TypeError: if n_pixels == 2: channels_pixels_nm = [ channels_pixels_nm for _ in range(n_channels) ] elif n_pixels == n_channels: channels_pixels_nm = [ (pix, pix) for pix in channels_pixels_nm ] else: raise ValueError( "channels_pixels_nm does not have the same length as channels" ) else: if n_pixels != n_channels: raise ValueError( "channels_pixels_nm does not have the same length as channels" ) return channels_pixels_nm def _normalize_channels_parameter(values, n_channels, name): """Normalize scalar/per-channel values to a per-channel sequence.""" try: if len(values) != n_channels: raise ValueError(f"{name} does not have the same length as channels") except TypeError: values = [values for _ in range(n_channels)] return values def _normalize_globdet_groups(globdet_groups, n_channels): """Normalize group definitions into 0-based channel index groups.""" if globdet_groups is None: return [list(range(n_channels))] normalized = [] for group in globdet_groups: if len(group) == 0: raise ValueError("globdet_groups cannot contain empty groups") group_indices = [int(index) for index in group] if min(group_indices) < 0 or max(group_indices) >= n_channels: raise ValueError("globdet_groups contains channel indices out of bounds") normalized.append(group_indices) return normalized def _transform_valid_mask(shape, matrix, cuda=False): """Transform a mask that tracks pixels not filled by affine borders.""" xp = get_xp(cuda) mask = xp.ones(shape, dtype=xp.float32) mask = img_transform( mask, matrix=matrix, cuda=cuda, parallel=False, order=0, cval=0.0, ) return mask > 0.5 def _largest_centered_valid_crop(valid_masks): """Find the largest centered crop valid in every mask.""" base_shape = np.asarray( ( min([mask.shape[0] for mask in valid_masks]), min([mask.shape[1] for mask in valid_masks]), ), dtype=int, ) lo, hi = 1, int(base_shape.min()) best_shape = None while lo <= hi: mid = (lo + hi) // 2 scale = mid / base_shape.min() crop_shape = np.maximum(1, np.floor(base_shape * scale).astype(int)) if _valid_centered_crop(valid_masks, crop_shape): best_shape = tuple(int(size) for size in crop_shape) lo = mid + 1 else: hi = mid - 1 if best_shape is None: raise ValueError("No centered crop without transformed border pixels found") return best_shape def _valid_centered_crop(valid_masks, crop_shape): """Return whether the centered crop is fully valid in every mask.""" for mask in valid_masks: y0 = (mask.shape[0] - crop_shape[0]) // 2 x0 = (mask.shape[1] - crop_shape[1]) // 2 y1 = y0 + crop_shape[0] x1 = x0 + crop_shape[1] valid = mask[y0:y1, x0:x1].all() if hasattr(valid, "item"): valid = valid.item() if not valid: return False return True def _center_crop_arrays(arrays, crop_shape, stacks=False): """Center-crop arrays to a shared spatial shape.""" spatial_axis = int(stacks) cropped = [] crop_bboxes = [] for array in arrays: y0 = (array.shape[spatial_axis] - crop_shape[0]) // 2 x0 = (array.shape[spatial_axis + 1] - crop_shape[1]) // 2 y1 = y0 + crop_shape[0] x1 = x0 + crop_shape[1] slices = [slice(None) for _ in range(array.ndim)] slices[spatial_axis] = slice(y0, y1) slices[spatial_axis + 1] = slice(x0, x1) cropped.append(array[tuple(slices)]) crop_bboxes.append((x0, y0, x1, y1)) return cropped, crop_shape, crop_bboxes def _copy_to_output(array, outputs, index, stacks=False, name="outputs"): """Copy an array into a reusable output buffer when one is available.""" if outputs is None: return array if len(outputs) <= index: raise ValueError(f"{name} does not have enough output arrays") out = _output_view(outputs[index], array.shape, stacks=stacks) if out is None: return array out[...] = array return out def _output_view(output, shape, stacks=False): """Return a compatible view into an output buffer, if possible.""" if output.shape == shape: return output if output.ndim != len(shape): return None if any(out_size < size for out_size, size in zip(output.shape, shape)): return None slices = [] for axis, (out_size, size) in enumerate(zip(output.shape, shape)): start = 0 if stacks and axis == 0 else (out_size - size) // 2 slices.append(slice(start, start + size)) return output[tuple(slices)]