# Source code for smlmlp.modules.block_LP._functions.localization.locs_individual_splinefit

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author        : Lancelot PINCET
# GitHub        : https://github.com/LancelotPincet



from smlmlp import block, locs_individual_barycenter
from funclp import LM, MLE, LSE, Poisson, Normal, Spline3D
from arrlp import get_xp, nb_threads, coordinates
from ._channel_values import split_channel_origins, stack_channel_values

# NOTE(review): not referenced by any code visible in this module. The value
# 0.21 * 670 / 1.5 (= 93.8) looks like the usual Gaussian PSF-width estimate
# 0.21 * wavelength / NA with a 670 nm wavelength and NA 1.5, i.e. a sigma in
# nanometres — TODO confirm the intent before relying on it.
SIGMA = 0.21 * 670 / 1.5



@block()
def locs_individual_splinefit(
    crops,
    X0,
    Y0,
    /,
    ch=None,
    *,
    optimizer="lm",
    estimator="mle",
    distribution="poisson",
    channels_pixels_nm=1.0,
    channels_gains=1.0,
    channels_QE=1.0,
    cuda=False,
    parallel=False,
    channels_psf_3d_xtangents=None,
    channels_psf_3d_ytangents=None,
    channels_psf_3d_ztangents=None,
    channels_psf_3d_spline_coeffs=None,
):
    """
    Fit each crop independently with a 3D spline PSF model.

    The function loops through channels, initializes a :class:`funclp.Spline3D`
    model per event, runs the selected optimizer/estimator combination, and
    returns localized coordinates with fitted photometric values.

    Parameters
    ----------
    crops : sequence of array-like
        Sequence of crop stacks, one per channel, shaped ``(N, Y, X)``.
    X0 : array-like
        Detection-aligned 1D vector of x-origin pixel indices.
    Y0 : array-like
        Detection-aligned 1D vector of y-origin pixel indices.
    ch : array-like or None, optional
        One-based channel index for each detection. Required when ``crops`` has
        several channels.
    optimizer : str, optional
        Optimizer key (``"lm"``).
    estimator : str, optional
        Estimator key (``"mle"`` or ``"lse"``).
    distribution : str, optional
        Distribution key used by the estimator (``"poisson"`` or ``"normal"``).
    channels_pixels_nm : float or sequence, optional
        Pixel size in nanometers: a scalar (isotropic, shared by all channels),
        a single ``(py, px)`` pair shared by all channels, or one ``(py, px)``
        pair per channel.
    channels_gains : float or sequence, optional
        Gain value(s), scalar or one per channel, used to convert fitted
        amplitudes and offsets.
    channels_QE : float or sequence, optional
        Quantum efficiency value(s), scalar or one per channel, used to convert
        fitted amplitudes and offsets.
    cuda : bool, optional
        Whether to run the fit on GPU.
    parallel : bool, optional
        Whether to enable CPU parallelization (ignored when ``cuda`` is True).
    channels_psf_3d_xtangents : sequence
        Spline x tangents, one set per channel.
    channels_psf_3d_ytangents : sequence
        Spline y tangents, one set per channel.
    channels_psf_3d_ztangents : sequence
        Spline z tangents, one set per channel.
    channels_psf_3d_spline_coeffs : sequence
        Spline coefficients, one set per channel.

    Returns
    -------
    tuple
        A tuple ``(mux, muy, muz, info)`` where:

        - ``mux`` is the detection-aligned x localization array in nanometers,
        - ``muy`` is the detection-aligned y localization array in nanometers,
        - ``muz`` is the detection-aligned z localization array (returned as
          fitted by the spline model; not rescaled here),
        - ``info`` is a dictionary with fitted parameter arrays.

        ``info`` contains:

        ``'amp'``
            Detection-aligned amplitudes converted as ``amp / QE * gain``.
        ``'offset'``
            Detection-aligned offsets converted as ``offset / QE * gain``.
        ``'converged'``
            Detection-aligned convergence flags: the optimizer's ``converged``
            attribute when it exposes one (zeros otherwise), set to ``-4``
            where the fitted x or y lies on the crop boundary.

    Raises
    ------
    SyntaxError
        If a ``channels_psf_3d_*`` argument is missing, or a key is unknown.
    ValueError
        If a per-channel argument does not have one entry per channel.

    Notes
    -----
    1. ``X0`` and ``Y0`` are split by ``ch`` so each origin vector follows the
       crop order inside its channel stack.
    2. Initial x/y positions are the barycenters from
       :func:`locs_individual_barycenter`, expressed in the crop's local frame.
       Non-finite barycenters fall back to the crop center, and all starts are
       clipped to the crop extent. z starts at zero; amplitude and offset start
       from the crop's ``max - min`` and ``min``.
    3. Fitted x/y are bounded to the crop extent. Local coordinates are then
       shifted by the crop origins (converted to nm with the channel pixel
       size), and all outputs are remapped to detection order.
    4. Channels without crops contribute empty arrays.

    Examples
    --------
    >>> import numpy as np
    >>> crops = [np.random.rand(2, 7, 7).astype(np.float32)]
    >>> x0 = np.array([10, 20], dtype=np.float32)
    >>> y0 = np.array([30, 40], dtype=np.float32)
    >>> tx = [np.linspace(-1.0, 1.0, 5, dtype=np.float32)]
    >>> ty = [np.linspace(-1.0, 1.0, 5, dtype=np.float32)]
    >>> tz = [np.linspace(-0.5, 0.5, 5, dtype=np.float32)]
    >>> coeffs = [np.ones((4, 4, 4), dtype=np.float32)]
    >>> mux, muy, muz, info = locs_individual_splinefit(
    ...     crops,
    ...     x0,
    ...     y0,
    ...     channels_pixels_nm=[(100.0, 100.0)],
    ...     channels_psf_3d_xtangents=tx,
    ...     channels_psf_3d_ytangents=ty,
    ...     channels_psf_3d_ztangents=tz,
    ...     channels_psf_3d_spline_coeffs=coeffs,
    ... )
    >>> mux.shape == muy.shape == muz.shape
    True
    >>> sorted(info)
    ['amp', 'converged', 'offset']
    >>> mux, muy, muz, info = locs_individual_splinefit(
    ...     crops,
    ...     x0,
    ...     y0,
    ...     channels_pixels_nm=[(100.0, 120.0)],
    ...     channels_psf_3d_xtangents=tx,
    ...     channels_psf_3d_ytangents=ty,
    ...     channels_psf_3d_ztangents=tz,
    ...     channels_psf_3d_spline_coeffs=coeffs,
    ... )
    >>> info['amp'].ndim
    1
    """
    n_channels = len(crops)
    X0_input, Y0_input = X0, Y0
    channels_pixels_nm = _normalize_channels_pixels_nm(channels_pixels_nm, n_channels)
    channels_gains = _normalize_channels_parameter(channels_gains, n_channels)
    channels_QE = _normalize_channels_parameter(channels_QE, n_channels)
    optimizer = _resolve_optimizer(optimizer)
    distribution = _resolve_distribution(distribution)
    estimator = _resolve_estimator(estimator, distribution)

    # Checked in a fixed order so the first reported problem is deterministic.
    for name, values in (
        ("channels_psf_3d_xtangents", channels_psf_3d_xtangents),
        ("channels_psf_3d_ytangents", channels_psf_3d_ytangents),
        ("channels_psf_3d_ztangents", channels_psf_3d_ztangents),
        ("channels_psf_3d_spline_coeffs", channels_psf_3d_spline_coeffs),
    ):
        _require_channels_spline_argument(values, name, n_channels)

    X0, Y0, positions = split_channel_origins(crops, X0_input, Y0_input, ch, cuda=cuda)

    # Detection-aligned barycenters seed the fits; split them per channel exactly
    # like the origins so they follow the same crop order.
    bary_x, bary_y, _ = locs_individual_barycenter(
        crops,
        X0_input,
        Y0_input,
        ch=ch,
        channels_pixels_nm=channels_pixels_nm,
        cuda=cuda,
        parallel=parallel,
    )
    bary_x, bary_y, _ = split_channel_origins(crops, bary_x, bary_y, ch, cuda=cuda)

    fit_kwargs = [
        dict(tx=tx, ty=ty, tz=tz, coeffs=coeffs)
        for tx, ty, tz, coeffs in zip(
            channels_psf_3d_xtangents,
            channels_psf_3d_ytangents,
            channels_psf_3d_ztangents,
            channels_psf_3d_spline_coeffs,
        )
    ]

    xp = get_xp(cuda)
    # One list per output (mux, muy, muz, amp, offset, converged), one entry per channel.
    per_channel = [[] for _ in range(6)]
    for crop, x0, y0, bary_x_ch, bary_y_ch, pixel, gain, qe, function_kw in zip(
        crops,
        X0,
        Y0,
        bary_x,
        bary_y,
        channels_pixels_nm,
        channels_gains,
        channels_QE,
        fit_kwargs,
    ):
        crop = xp.asarray(crop)
        if len(crop) == 0:
            # Checked before unpacking crop.shape so degenerate empty stacks
            # still produce aligned (empty) per-channel results.
            empty = xp.empty(0, dtype=xp.float32)
            outputs = (empty, empty, empty, empty, empty, xp.empty(0, dtype=xp.int8))
        else:
            outputs = _fit_spline_channel(
                xp,
                crop,
                x0,
                y0,
                bary_x_ch,
                bary_y_ch,
                pixel,
                gain,
                qe,
                function_kw,
                optimizer,
                estimator,
                cuda=cuda,
                parallel=parallel,
            )
        if cuda:
            # Convert every channel, empty ones included, so all per-channel
            # results handed to stack_channel_values are host arrays.
            outputs = tuple(xp.asnumpy(value) for value in outputs)
        for column, value in zip(per_channel, outputs):
            column.append(value)

    mux_all, muy_all, muz_all, amp_all, offset_all, converged_all = per_channel
    info = {
        "amp": stack_channel_values(amp_all, positions),
        "offset": stack_channel_values(offset_all, positions),
        "converged": stack_channel_values(converged_all, positions),
    }
    return (
        stack_channel_values(mux_all, positions),
        stack_channel_values(muy_all, positions),
        stack_channel_values(muz_all, positions),
        info,
    )


def _require_channels_spline_argument(values, name, n_channels):
    """Raise if a per-channel spline argument is missing or has the wrong length."""
    if values is None:
        raise SyntaxError(f"{name} must be specified as a kwarg")
    if len(values) != n_channels:
        raise ValueError(f"{name} does not have the same length as crops")


def _fit_spline_channel(
    xp,
    crop,
    x0,
    y0,
    bary_x,
    bary_y,
    pixel,
    gain,
    qe,
    function_kw,
    optimizer,
    estimator,
    *,
    cuda,
    parallel,
):
    """Fit every crop of one non-empty channel stack with a Spline3D model.

    ``crop`` is the channel's ``(N, Y, X)`` stack, ``x0``/``y0`` its crop origins
    in pixels, ``bary_x``/``bary_y`` its barycenters, and ``pixel`` a ``(py, px)``
    pair in nm. Returns ``(mux, muy, muz, amp, offset, converged)`` in the
    channel's crop order, still on the array backend selected by ``xp``.
    """
    _, height, width = crop.shape
    yy, xx = coordinates(shape=(height, width), center=False, pixel=pixel, cuda=cuda)
    zz = xp.zeros_like(xx)
    # Crop origins are pixel indices; scale them to nm (pixel is (py, px)).
    x0 = xp.asarray(x0) * pixel[1]
    y0 = xp.asarray(y0) * pixel[0]

    center_x = (width - 1) / 2 * pixel[1]
    center_y = (height - 1) / 2 * pixel[0]

    # Start from the barycenter in the crop's local frame; fall back to the crop
    # center where the barycenter is not finite.
    mux_start = xp.asarray(bary_x) - x0
    muy_start = xp.asarray(bary_y) - y0
    mux_start = xp.where(xp.isfinite(mux_start), mux_start, center_x)
    muy_start = xp.where(xp.isfinite(muy_start), muy_start, center_y)

    # Per-event bounds covering the crop extent (local nm coordinates).
    mux_min = xp.zeros_like(mux_start, dtype=xp.float32)
    mux_max = xp.full_like(mux_start, fill_value=(width - 1) * pixel[1], dtype=xp.float32)
    muy_min = xp.zeros_like(muy_start, dtype=xp.float32)
    muy_max = xp.full_like(muy_start, fill_value=(height - 1) * pixel[0], dtype=xp.float32)

    offset = xp.min(crop, axis=(1, 2))
    function = Spline3D(
        mux=xp.clip(mux_start, mux_min, mux_max),
        muy=xp.clip(muy_start, muy_min, muy_max),
        muz=xp.zeros_like(x0),
        amp=xp.max(crop, axis=(1, 2)) - offset,
        offset=offset,
        cuda=cuda,
        **function_kw,
    )
    function.amp_min = 0.0
    function.offset_min = 0.0
    function.mux_min = mux_min
    function.mux_max = mux_max
    function.muy_min = muy_min
    function.muy_max = muy_max

    fit = optimizer(function, estimator)
    if cuda:
        fit(crop, xx, yy, zz)
    else:
        with nb_threads(parallel):
            fit(crop, xx, yy, zz)

    # Fits whose x or y ended on a crop-boundary bound are flagged with -4.
    converged = getattr(fit, "converged", xp.zeros(len(crop), dtype=xp.int8))
    clamped = xp.isclose(function.mux, mux_min) | xp.isclose(function.mux, mux_max)
    clamped |= xp.isclose(function.muy, muy_min) | xp.isclose(function.muy, muy_max)
    converged = xp.where(clamped, xp.full_like(converged, -4), converged)

    # In-place shift to the global frame keeps the fitted arrays' dtype; it also
    # updates function.mux/muy, but the model is discarded after this call.
    mux, muy, muz = function.mux, function.muy, function.muz
    mux += x0
    muy += y0
    amp = function.amp / qe * gain
    offset = function.offset / qe * gain
    return mux, muy, muz, amp, offset, converged
def _normalize_channels_pixels_nm(channels_pixels_nm, n_channels): """Normalize pixel sizes to one ``(py, px)`` tuple per channel.""" try: if len(channels_pixels_nm) != n_channels: if len(channels_pixels_nm) == 2: channels_pixels_nm = [channels_pixels_nm for _ in range(n_channels)] else: raise ValueError( "channel_mean_radius_pix does not have the same length as channels" ) except TypeError: channels_pixels_nm = [ (channels_pixels_nm, channels_pixels_nm) for _ in range(n_channels) ] return channels_pixels_nm def _normalize_channels_parameter(values, n_channels): """Normalize scalar/per-channel values to a per-channel sequence.""" try: if len(values) != n_channels: raise ValueError( "channel_mean_radius_pix does not have the same length as channels" ) except TypeError: values = [values for _ in range(n_channels)] return values def _resolve_optimizer(optimizer): """Resolve optimizer key to optimizer class.""" match optimizer.lower(): case "lm": return LM case _: raise SyntaxError(f"Optimizer {optimizer} is not recognized") def _resolve_distribution(distribution): """Resolve distribution key to instantiated distribution.""" match distribution.lower(): case "normal": return Normal() case "poisson": return Poisson() case _: raise SyntaxError(f"Distribution {distribution} is not recognized") def _resolve_estimator(estimator, distribution): """Resolve estimator key to instantiated estimator.""" match estimator.lower(): case "mle": return MLE(distribution) case "lse": return LSE() case _: raise SyntaxError(f"Estimator {estimator} is not recognized")