Source code for smlmlp.modules.block_LP._functions.globlocalization.globloc_fit

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author        : Lancelot PINCET
# GitHub        : https://github.com/LancelotPincet



from smlmlp import block
from funclp import LM, MLE, LSE, Poisson, Normal, Gaussian2D, IsoGaussian, Spline2D, JointFunction, JointChannel
from arrlp import get_xp, nb_threads, coordinates
import numpy as np

# Default PSF width used for ``sigx``/``sigy``/``sig`` when a channel's
# fit_init does not provide one; presumably in nm, like the pixel sizes it is
# combined with. NOTE(review): the expression looks like the usual
# 0.21 * wavelength / NA Gaussian approximation of the PSF
# (wavelength 670 nm, NA 1.5) — confirm with the author.
SIGMA = 0.21 * 670 / 1.5



@block()
def globloc_fit(
    crops,
    X0,
    Y0,
    /,
    channels_models,
    channels_fit_inits,
    *,
    optimizer="lm",
    estimator="mle",
    distribution="poisson",
    channels_pixels_nm=1.0,
    channels_gains=1.0,
    channels_QE=1.0,
    cuda=False,
    parallel=False,
):
    """Fit global localizations from channel crops using joint fitting.

    Uses :class:`funclp.JointFunction` to fit all crops simultaneously,
    sharing position parameters (``mux``, ``muy``) and coordinate variables
    across channels. With a single channel the model is fitted directly.

    Parameters
    ----------
    crops : sequence of ndarray
        Crop stacks to fit, one per channel, shaped ``(N, Y, X)``.
    X0 : sequence of ndarray
        Crop x origins in pixels, one array per channel.
    Y0 : sequence of ndarray
        Crop y origins in pixels, one array per channel.
    channels_models : sequence of str
        Model per channel, one of ``"gauss"``, ``"isogauss"``, ``"spline"``
        (case-insensitive).
    channels_fit_inits : sequence of dict
        Initial fit parameters per channel. ``"gauss"`` reads ``sigx``,
        ``sigy``, ``theta`` and ``theta_fit``; ``"isogauss"`` reads ``sig``;
        ``"spline"`` reads ``tx``, ``ty`` and ``coeffs``. Missing widths
        default to the module constant ``SIGMA``.
    optimizer : str, optional
        Optimizer key (``"lm"``).
    estimator : str, optional
        Estimator key (``"mle"`` or ``"lse"``).
    distribution : str, optional
        Distribution key used by the estimator (``"poisson"`` or ``"normal"``).
    channels_pixels_nm : float or sequence, optional
        Pixel size specification per channel, in ``(py, px)`` order.
    channels_gains : float or sequence, optional
        Gain value(s); fitted amplitudes and offsets are reported as
        ``fitted_value / QE * gain``.
    channels_QE : float or sequence, optional
        Quantum efficiency value(s), see ``channels_gains``.
    cuda : bool, optional
        Whether to use CUDA execution.
    parallel : bool, optional
        Whether to use parallel execution (forwarded to ``nb_threads``).

    Returns
    -------
    tuple
        A tuple ``(mux, muy, info)`` where:

        - ``mux`` is the concatenated x localization array in nanometers
          (channels in order, crop origin added back),
        - ``muy`` is the concatenated y localization array in nanometers,
        - ``info`` is a dictionary with fitted parameter arrays aligned with
          ``mux``: ``"amp"`` and ``"offset"`` always; ``"sigmax"`` and
          ``"sigmay"`` when any channel uses ``"gauss"``; ``"sigma"`` when any
          channel uses ``"isogauss"``. Entries belonging to channels whose
          model has no such parameter are NaN.

    Raises
    ------
    ValueError
        If a per-channel input does not have one entry per crop stack, if
        ``crops`` is empty, or if a model name is not recognized.
    SyntaxError
        If ``optimizer``, ``estimator`` or ``distribution`` is not recognized
        (raised by the private resolver helpers).

    Examples
    --------
    >>> import numpy as np
    >>> crops = [np.random.rand(2, 7, 7).astype(np.float32)]
    >>> x0 = [np.array([10, 20], dtype=np.float32)]
    >>> y0 = [np.array([30, 40], dtype=np.float32)]
    >>> models = ["gauss"]
    >>> inits = [{"sigx": 90.0, "sigy": 90.0, "theta": 0.0, "theta_fit": False}]
    >>> mux, muy, info = globloc_fit(
    ...     crops, x0, y0,
    ...     channels_models=models,
    ...     channels_fit_inits=inits,
    ...     channels_pixels_nm=[(100.0, 100.0)],
    ... )
    >>> mux.shape == muy.shape
    True
    >>> crops = [np.random.rand(2, 7, 7).astype(np.float32)]
    >>> x0 = [np.array([10, 20], dtype=np.float32)]
    >>> y0 = [np.array([30, 40], dtype=np.float32)]
    >>> models = ["isogauss"]
    >>> inits = [{"sig": 90.0}]
    >>> mux, muy, info = globloc_fit(
    ...     crops, x0, y0,
    ...     channels_models=models,
    ...     channels_fit_inits=inits,
    ...     channels_pixels_nm=[(100.0, 100.0)],
    ... )
    >>> info['sigma'].ndim
    1
    """
    n_channels = len(crops)
    channels_pixels_nm = _normalize_channels_pixels_nm(channels_pixels_nm, n_channels)
    channels_gains = _normalize_channels_parameter(channels_gains, n_channels)
    channels_QE = _normalize_channels_parameter(channels_QE, n_channels)
    optimizer_cls = _resolve_optimizer(optimizer)
    distribution = _resolve_distribution(distribution)
    estimator = _resolve_estimator(estimator, distribution)

    if len(channels_models) != n_channels:
        raise ValueError("channels_models must have same length as crops")
    if len(channels_fit_inits) != n_channels:
        raise ValueError("channels_fit_inits must have same length as crops")
    # Without this check zip() would silently drop channels lacking an origin.
    if len(X0) != n_channels or len(Y0) != n_channels:
        raise ValueError("X0 and Y0 must have same length as crops")
    if n_channels == 0:
        raise ValueError("crops must contain at least one channel")

    xp = get_xp(cuda)
    model_names = [model_name.lower() for model_name in channels_models]

    functions = []
    origins_nm = []  # per-channel (x0, y0) crop origins converted to nm
    crop_data = []
    xy_data = []
    for crop, x0, y0, pixel, model_name, fit_init in zip(
        crops, X0, Y0, channels_pixels_nm, model_names, channels_fit_inits
    ):
        crop = xp.asarray(crop)
        _, height, width = crop.shape
        yy, xx = coordinates(shape=(height, width), pixel=pixel, cuda=cuda)

        # Origins are given in pixels; pixel is (py, px).
        x0 = xp.asarray(x0) * pixel[1]
        y0 = xp.asarray(y0) * pixel[0]

        # Initial guesses: crop centre, brightest and dimmest pixel. A
        # floating dtype keeps integer inputs from truncating them.
        mux = xp.full(
            x0.shape,
            (width - 1) / 2 * pixel[1],
            dtype=xp.result_type(x0.dtype, xp.float32),
        )
        muy = xp.full(
            y0.shape,
            (height - 1) / 2 * pixel[0],
            dtype=xp.result_type(y0.dtype, xp.float32),
        )
        intensity_dtype = xp.result_type(crop.dtype, xp.float32)
        amp = xp.max(crop, axis=(1, 2)).astype(intensity_dtype, copy=False)
        offset = xp.min(crop, axis=(1, 2)).astype(intensity_dtype, copy=False)

        functions.append(
            _build_channel_function(
                model_name, fit_init, mux, muy, amp, offset, pixel, cuda
            )
        )
        origins_nm.append((x0, y0))
        crop_data.append(crop)
        xy_data.append({"x": xx, "y": yy})

    if n_channels == 1:
        fit = optimizer_cls(functions[0], estimator)
        fit_args = (crop_data[0], xy_data[0]["x"], xy_data[0]["y"])
        if cuda:
            fit(*fit_args)
        else:
            with nb_threads(parallel):
                fit(*fit_args)
    else:
        prefixes = [f"ch{i}" for i in range(n_channels)]
        joint_channels = [
            JointChannel(func, prefix=prefix)
            for func, prefix in zip(functions, prefixes)
        ]
        # Every channel evaluates on its own "x"/"y" grid but they are bound
        # to the same variable names, and the emitter position is shared.
        shared_vars = {
            "x": ["x"] * n_channels,
            "y": ["y"] * n_channels,
        }
        shared_params = {
            "mux": ["mux"] * n_channels,
            "muy": ["muy"] * n_channels,
        }
        joint_function = JointFunction(
            joint_channels,
            shared_variables=shared_vars,
            shared_parameters=shared_params,
        )
        fit = optimizer_cls(joint_function, estimator)
        # NOTE(review): unlike the single-channel branch, nb_threads is also
        # entered when cuda=True; kept as-is — confirm intent.
        with nb_threads(parallel):
            fit(crop_data, xy_data)

    mux_all, muy_all, amp_all, offset_all = [], [], [], []
    sigmax_all, sigmay_all, sigma_all = [], [], []
    for function, (x0, y0), gain, qe, model_name in zip(
        functions, origins_nm, channels_gains, channels_QE, model_names
    ):
        # Fitted positions are crop-local; add the crop origin back.
        mux_all.append(_to_host(function.mux + x0, xp, cuda))
        muy_all.append(_to_host(function.muy + y0, xp, cuda))
        amp_all.append(_to_host(function.amp / qe * gain, xp, cuda))
        offset_all.append(_to_host(function.offset / qe * gain, xp, cuda))

        sigx = sigy = sig = None
        if model_name == "gauss":
            sigx = _to_host(function.sigx, xp, cuda)
            sigy = _to_host(function.sigy, xp, cuda)
        elif model_name == "isogauss":
            sig = _to_host(function.sig, xp, cuda)
        # "spline" channels have no width parameter to report.
        sigmax_all.append(sigx)
        sigmay_all.append(sigy)
        sigma_all.append(sig)

    sizes = [np.size(channel_mux) for channel_mux in mux_all]
    info = {
        "amp": np.hstack(amp_all),
        "offset": np.hstack(offset_all),
    }
    sigmax = _stack_aligned(sigmax_all, sizes)
    if sigmax is not None:
        info["sigmax"] = sigmax
        info["sigmay"] = _stack_aligned(sigmay_all, sizes)
    sigma = _stack_aligned(sigma_all, sizes)
    if sigma is not None:
        info["sigma"] = sigma
    return np.hstack(mux_all), np.hstack(muy_all), info


def _build_channel_function(model_name, fit_init, mux, muy, amp, offset, pixel, cuda):
    """Create one channel's funclp model from its initial guesses.

    ``model_name`` must already be lower-case; ``pixel`` is ``(py, px)``.

    Raises
    ------
    ValueError
        If ``model_name`` is not ``"gauss"``, ``"isogauss"`` or ``"spline"``.
    """
    common = {
        "mux": mux,
        "muy": muy,
        "amp": amp,
        "offset": offset,
        "cuda": cuda,
        "pixx": pixel[1],
        "pixy": pixel[0],
    }
    if model_name == "gauss":
        return Gaussian2D(
            **common,
            sigx=fit_init.get("sigx", SIGMA),
            sigy=fit_init.get("sigy", SIGMA),
            theta=fit_init.get("theta", 0.0),
            theta_fit=fit_init.get("theta_fit", False),
        )
    if model_name == "isogauss":
        return IsoGaussian(**common, sig=fit_init.get("sig", SIGMA))
    if model_name == "spline":
        return Spline2D(
            **common,
            tx=fit_init.get("tx"),
            ty=fit_init.get("ty"),
            coeffs=fit_init.get("coeffs"),
        )
    raise ValueError(f"Unknown model: {model_name}")


def _to_host(array, xp, cuda):
    """Copy ``array`` to host with ``xp.asnumpy`` when ``cuda``; else return it unchanged."""
    return xp.asnumpy(array) if cuda else array


def _stack_aligned(per_channel, sizes):
    """Concatenate per-channel arrays, NaN-filling channels that have none.

    ``per_channel`` holds one array (or ``None``) per channel and ``sizes``
    the number of localizations of each channel. Returns ``None`` when no
    channel provides values, so the caller can omit the key entirely.
    """
    if all(values is None for values in per_channel):
        return None
    return np.hstack(
        [
            np.full(size, np.nan) if values is None else values
            for values, size in zip(per_channel, sizes)
        ]
    )
def _normalize_channels_pixels_nm(channels_pixels_nm, n_channels): """Normalize pixel sizes to one ``(py, px)`` tuple per channel.""" try: if len(channels_pixels_nm) != n_channels: if len(channels_pixels_nm) == 2: channels_pixels_nm = [channels_pixels_nm for _ in range(n_channels)] else: raise ValueError( "channel_mean_radius_pix does not have the same length as channels" ) except TypeError: channels_pixels_nm = [ (channels_pixels_nm, channels_pixels_nm) for _ in range(n_channels) ] return channels_pixels_nm def _normalize_channels_parameter(values, n_channels): """Normalize scalar/per-channel values to a per-channel sequence.""" try: if len(values) != n_channels: raise ValueError( "channel_mean_radius_pix does not have the same length as channels" ) except TypeError: values = [values for _ in range(n_channels)] return values def _resolve_optimizer(optimizer): """Resolve optimizer key to optimizer class.""" match optimizer.lower(): case "lm": return LM case _: raise SyntaxError(f"Optimizer {optimizer} is not recognized") def _resolve_distribution(distribution): """Resolve distribution key to instantiated distribution.""" match distribution.lower(): case "normal": return Normal() case "poisson": return Poisson() case _: raise SyntaxError(f"Distribution {distribution} is not recognized") def _resolve_estimator(estimator, distribution): """Resolve estimator key to instantiated estimator.""" match estimator.lower(): case "mle": return MLE(distribution) case "lse": return LSE() case _: raise SyntaxError(f"Estimator {estimator} is not recognized")