#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author : Lancelot PINCET
# GitHub : https://github.com/LancelotPincet
import numpy as np
from arrlp import coordinates, get_xp, nb_threads
from funclp import Gaussian2D, IsoGaussian, LM, LSE, MLE, Normal, Poisson, Spline3D
from smlmlp import block, locs_individual_barycenter
from ._channel_values import split_channel_origins, stack_channel_values
# Default PSF sigma, used below as the default of the ``channels_psf_*sigmas_nm``
# arguments (hence expressed in nanometers).
# NOTE(review): looks like the 0.21 * wavelength / NA Gaussian approximation
# with a 670 nm wavelength and NA = 1.5 — confirm.
SIGMA = 0.21 * 670 / 1.5
# Model names accepted in ``channels_fit_models``.
FIT_MODELS = ("isogauss", "gauss", "spline")
# NOTE: a stray "[docs]" Sphinx source-link marker stood here; it is not Python code.
@block()
def locs_individual_fit(
    crops, X0, Y0, /, ch=None, *,
    channels_fit_models="isogauss", optimizer="lm", estimator="mle", distribution="poisson",
    channels_pixels_nm=1.0, channels_gains=1.0, channels_QE=1.0, cuda=False, parallel=False,
    channels_psf_sigmas_nm=SIGMA, channels_psf_xsigmas_nm=SIGMA, channels_psf_ysigmas_nm=SIGMA,
    channels_psf_thetas_deg=0.0, channels_fit_thetas=False,
    channels_psf_3d_xtangents=None, channels_psf_3d_ytangents=None, channels_psf_3d_ztangents=None,
    channels_psf_3d_spline_coeffs=None,
):
    """
    Fit each crop stack with a channel-specific localization model.

    Parameters
    ----------
    crops : sequence of array-like
        Crop stacks to fit, one stack per channel. Each stack must be shaped
        ``(N, Y, X)`` where ``N`` is the number of events in that channel.
    X0 : array-like
        Detection-aligned 1D vector of crop x-origin pixel indices.
    Y0 : array-like
        Detection-aligned 1D vector of crop y-origin pixel indices.
    ch : array-like or None, optional
        One-based channel index for each detection. Required when ``crops`` has
        several channels.
    channels_fit_models : str or sequence of str, default="isogauss"
        Model used for each channel crop stack. Accepted values are
        ``"isogauss"``, ``"gauss"``, and ``"spline"``.
    optimizer : {"lm"}, default="lm"
        Optimizer used to fit each model.
    estimator : {"mle", "lse"}, default="mle"
        Estimator used by the optimizer.
    distribution : {"poisson", "normal"}, default="poisson"
        Noise distribution used by maximum-likelihood estimators.
    channels_pixels_nm : float, tuple, or sequence, default=1.0
        Pixel size specification. A scalar is used for both axes and all
        channels, a ``(py, px)`` tuple is broadcast to all channels, and a
        sequence provides one ``(py, px)`` pair per channel.
    channels_gains : float or sequence, default=1.0
        Gain values used to convert fitted amplitudes and offsets.
    channels_QE : float or sequence, default=1.0
        Quantum efficiencies used to convert fitted amplitudes and offsets.
    cuda : bool, default=False
        Whether to run fits on CUDA when supported.
    parallel : bool, default=False
        Whether to use threaded CPU execution.
    channels_psf_sigmas_nm : float or sequence, default=SIGMA
        Initial isotropic sigma values for ``"isogauss"`` channels.
    channels_psf_xsigmas_nm : float or sequence, default=SIGMA
        Initial x sigma values for ``"gauss"`` channels.
    channels_psf_ysigmas_nm : float or sequence, default=SIGMA
        Initial y sigma values for ``"gauss"`` channels.
    channels_psf_thetas_deg : float or sequence, default=0.0
        Initial theta values, in degrees, for ``"gauss"`` channels.
    channels_fit_thetas : bool or sequence, default=False
        Whether to fit theta for ``"gauss"`` channels.
    channels_psf_3d_xtangents : sequence, optional
        Spline x tangents for ``"spline"`` channels.
    channels_psf_3d_ytangents : sequence, optional
        Spline y tangents for ``"spline"`` channels.
    channels_psf_3d_ztangents : sequence, optional
        Spline z tangents for ``"spline"`` channels.
    channels_psf_3d_spline_coeffs : sequence, optional
        Spline coefficients for ``"spline"`` channels.

    Returns
    -------
    tuple
        A tuple ``(mux, muy, muz, info)`` where ``mux`` and ``muy`` are
        detection-aligned fitted coordinates in nanometers and ``muz`` contains
        fitted spline z coordinates or ``np.nan`` for 2D models. ``info``
        contains detection-aligned ``"amp"``, ``"offset"``, optional
        ``"sigma"``, optional ``"sigmax"``, optional ``"sigmay"`` and
        ``"converged"`` arrays, plus a ``"models"`` list.
        ``"sigma"`` is ``None`` when no channel uses ``"isogauss"``;
        ``"sigmax"``/``"sigmay"`` are ``None`` when no channel uses ``"gauss"``.
        Present sigma arrays match localization length and contain ``np.nan``
        where a quantity does not apply to a given channel model.
        ``"converged"`` holds the optimizer's ``converged`` attribute (zeros
        when the optimizer does not expose one), overwritten with ``-4`` for
        fits whose position or sigma ended on one of the imposed bounds.

    Raises
    ------
    ValueError
        If per-channel inputs have incompatible lengths or a model name is not
        one of ``"isogauss"``, ``"gauss"``, and ``"spline"``.
    SyntaxError
        If the optimizer, estimator, distribution, or required spline metadata
        is missing or unsupported.

    Notes
    -----
    1. Channel model names and per-channel parameters are normalized to match
       the crop list length.
    2. ``X0`` and ``Y0`` are split by ``ch`` so origins match each crop stack.
    3. Each channel selects ``IsoGaussian``, ``Gaussian2D``, or ``Spline3D`` and
       initializes local coordinates from the crop barycenter, falling back to
       the crop center when the barycenter is not finite.
    4. The selected optimizer updates model parameters in local nanometer
       coordinates before crop origins are added back.
    5. Coordinates and fitted parameter arrays are remapped to detection order.

    Examples
    --------
    >>> import numpy as np
    >>> crops = [np.random.rand(2, 7, 7).astype(np.float32)]
    >>> x0 = np.array([10, 20], dtype=np.float32)
    >>> y0 = np.array([30, 40], dtype=np.float32)
    >>> mux, muy, muz, info = locs_individual_fit(
    ...     crops,
    ...     x0,
    ...     y0,
    ...     channels_fit_models=["isogauss"],
    ...     channels_pixels_nm=[(100.0, 100.0)],
    ... )
    >>> mux.shape == muy.shape
    True
    >>> tx = [np.linspace(-300.0, 300.0, 8, dtype=np.float32)]
    >>> ty = [np.linspace(-300.0, 300.0, 8, dtype=np.float32)]
    >>> tz = [np.linspace(-300.0, 300.0, 8, dtype=np.float32)]
    >>> coeffs = [np.ones((4, 4, 4), dtype=np.float32)]
    >>> mux, muy, muz, info = locs_individual_fit(
    ...     crops,
    ...     x0,
    ...     y0,
    ...     channels_fit_models=["spline"],
    ...     channels_pixels_nm=[(100.0, 100.0)],
    ...     channels_psf_3d_xtangents=tx,
    ...     channels_psf_3d_ytangents=ty,
    ...     channels_psf_3d_ztangents=tz,
    ...     channels_psf_3d_spline_coeffs=coeffs,
    ... )
    >>> info["models"]
    ['spline']
    """
    n_channels = len(crops)
    X0_input, Y0_input = X0, Y0

    # Normalize per-channel parameters
    channels_fit_models = _normalize_channels_fit_models(channels_fit_models, n_channels)
    channels_pixels_nm = _normalize_channels_pixels_nm(channels_pixels_nm, n_channels)
    channels_gains = _normalize_channels_parameter(channels_gains, n_channels)
    channels_QE = _normalize_channels_parameter(channels_QE, n_channels)
    channels_psf_sigmas_nm = _normalize_channels_parameter(channels_psf_sigmas_nm, n_channels)
    channels_psf_xsigmas_nm = _normalize_channels_parameter(channels_psf_xsigmas_nm, n_channels)
    channels_psf_ysigmas_nm = _normalize_channels_parameter(channels_psf_ysigmas_nm, n_channels)
    channels_psf_thetas_deg = _normalize_channels_parameter(channels_psf_thetas_deg, n_channels)
    channels_fit_thetas = _normalize_channels_parameter(channels_fit_thetas, n_channels)

    # Normalize spline parameters (mandatory only when a channel uses "spline")
    needs_spline = "spline" in channels_fit_models
    channels_psf_3d_xtangents = _normalize_spline_parameter(
        channels_psf_3d_xtangents, n_channels, "channels_psf_3d_xtangents", required=needs_spline,
    )
    channels_psf_3d_ytangents = _normalize_spline_parameter(
        channels_psf_3d_ytangents, n_channels, "channels_psf_3d_ytangents", required=needs_spline,
    )
    channels_psf_3d_ztangents = _normalize_spline_parameter(
        channels_psf_3d_ztangents, n_channels, "channels_psf_3d_ztangents", required=needs_spline,
    )
    channels_psf_3d_spline_coeffs = _normalize_spline_parameter(
        channels_psf_3d_spline_coeffs, n_channels, "channels_psf_3d_spline_coeffs", required=needs_spline,
    )

    # Resolve optimizer and estimator before any array work so that an
    # unsupported name fails immediately instead of after the barycenter pass.
    optimizer_cls = _resolve_optimizer(optimizer)
    distribution = _resolve_distribution(distribution)
    estimator = _resolve_estimator(estimator, distribution)
    xp = get_xp(cuda)

    # Split origins by channel and compute barycenters used as initial guesses
    X0, Y0, positions = split_channel_origins(crops, X0_input, Y0_input, ch, cuda=cuda)
    bary_x, bary_y, _ = locs_individual_barycenter(
        crops,
        X0_input,
        Y0_input,
        ch=ch,
        channels_pixels_nm=channels_pixels_nm,
        cuda=cuda,
        parallel=parallel,
    )
    bary_x, bary_y, _ = split_channel_origins(crops, bary_x, bary_y, ch, cuda=cuda)

    # Initialize output containers (one entry appended per channel)
    has_isogauss = "isogauss" in channels_fit_models
    has_gauss = "gauss" in channels_fit_models
    mux_all, muy_all, amp_all, offset_all = [], [], [], []
    sigma_all, sigmax_all, sigmay_all, muz_all = [], [], [], []
    converged_all = []

    # Iterate over channels
    for crop, x0, y0, bary_x_ch, bary_y_ch, pixel, gain, qe, model, sigma, sigx, sigy, theta, fit_theta, tx, ty, tz, coeffs in zip(
        crops, X0, Y0, bary_x, bary_y, channels_pixels_nm, channels_gains, channels_QE, channels_fit_models,
        channels_psf_sigmas_nm, channels_psf_xsigmas_nm, channels_psf_ysigmas_nm, channels_psf_thetas_deg,
        channels_fit_thetas, channels_psf_3d_xtangents, channels_psf_3d_ytangents, channels_psf_3d_ztangents,
        channels_psf_3d_spline_coeffs,
    ):
        crop = xp.asarray(crop)
        n_events, height, width = crop.shape

        # Handle empty crops: nothing to fit, contribute empty placeholders.
        if n_events == 0:
            empty = xp.empty(0, dtype=xp.float32)
            empty_status = xp.empty(0, dtype=xp.int8)
            if cuda:
                # Same host conversion as for fitted channels, so every array
                # handed to stack_channel_values lives on the same side.
                empty = xp.asnumpy(empty)
                empty_status = xp.asnumpy(empty_status)
            mux_all.append(empty)
            muy_all.append(empty)
            muz_all.append(empty)
            amp_all.append(empty)
            offset_all.append(empty)
            converged_all.append(empty_status)
            # Mirror the conditional sigma bookkeeping of fitted channels.
            if has_isogauss:
                sigma_all.append(empty)
            if has_gauss:
                sigmax_all.append(empty)
                sigmay_all.append(empty)
            continue

        # Prepare coordinate grids and crop origins in nanometers
        yy, xx = coordinates(shape=(height, width), center=False, pixel=pixel, cuda=cuda)
        x0 = xp.asarray(x0) * pixel[1]
        y0 = xp.asarray(y0) * pixel[0]

        # Initial position: barycenter relative to the crop origin, replaced by
        # the crop center when not finite, then clipped inside the crop.
        center_x = (width - 1) / 2 * pixel[1]
        center_y = (height - 1) / 2 * pixel[0]
        bary_x_ch = xp.asarray(bary_x_ch)
        bary_y_ch = xp.asarray(bary_y_ch)
        mux_center = bary_x_ch - x0
        muy_center = bary_y_ch - y0
        mux_center = xp.where(xp.isfinite(mux_center), mux_center, center_x)
        muy_center = xp.where(xp.isfinite(muy_center), muy_center, center_y)
        mux_min = xp.zeros_like(mux_center, dtype=xp.float32)
        mux_max = xp.full_like(mux_center, fill_value=(width - 1) * pixel[1], dtype=xp.float32)
        muy_min = xp.zeros_like(muy_center, dtype=xp.float32)
        muy_max = xp.full_like(muy_center, fill_value=(height - 1) * pixel[0], dtype=xp.float32)
        mux = xp.clip(mux_center, mux_min, mux_max)
        muy = xp.clip(muy_center, muy_min, muy_max)

        # Initial photometry: per-crop minimum as offset, dynamic range as amplitude
        offset = xp.min(crop, axis=(1, 2))
        amp = xp.max(crop, axis=(1, 2)) - offset

        # Create model and set parameter bounds
        function = _make_model(
            model=model, mux=mux, muy=muy, amp=amp, offset=offset,
            pixel=pixel, sigma=sigma, sigx=sigx, sigy=sigy, theta=theta, fit_theta=fit_theta,
            tx=tx, ty=ty, tz=tz, coeffs=coeffs, cuda=cuda,
        )
        function.amp_min = 0.0
        function.offset_min = 0.0
        function.mux_min = mux_min
        function.mux_max = mux_max
        function.muy_min = muy_min
        function.muy_max = muy_max
        # Sigma bounds: from half a pixel up to half the crop extent
        if model == "isogauss":
            function.sig_min = min(float(pixel[0]), float(pixel[1])) / 2.0
            function.sig_max = min(float(height) * float(pixel[0]), float(width) * float(pixel[1])) / 2.0
        if model == "gauss":
            function.sigx_min = float(pixel[1]) / 2.0
            function.sigx_max = float(width) * float(pixel[1]) / 2.0
            function.sigy_min = float(pixel[0]) / 2.0
            function.sigy_max = float(height) * float(pixel[0]) / 2.0

        # Run optimizer; spline models take an extra (all-zero) z grid
        fit = optimizer_cls(function, estimator)
        if model == "spline":
            fit_args = (crop, xx, yy, xp.zeros_like(xx))
        else:
            fit_args = (crop, xx, yy)
        if cuda:
            fit(*fit_args)
        else:
            # The thread-count context is only applied to CPU execution
            with nb_threads(parallel):
                fit(*fit_args)
        if model == "spline":
            muz = function.muz
        else:
            muz = xp.full_like(x0, fill_value=np.nan, dtype=xp.float32)

        # Convergence status, with fits stuck on a bound flagged as -4
        converged = getattr(fit, "converged", xp.zeros(len(crop), dtype=xp.int8))
        clamped = xp.isclose(function.mux, mux_min) | xp.isclose(function.mux, mux_max)
        clamped |= xp.isclose(function.muy, muy_min) | xp.isclose(function.muy, muy_max)
        if model == "isogauss":
            clamped |= xp.isclose(function.sig, function.sig_min) | xp.isclose(function.sig, function.sig_max)
        if model == "gauss":
            clamped |= xp.isclose(function.sigx, function.sigx_min) | xp.isclose(function.sigx, function.sigx_max)
            clamped |= xp.isclose(function.sigy, function.sigy_min) | xp.isclose(function.sigy, function.sigy_max)
        converged = xp.where(clamped, xp.full_like(converged, -4), converged)

        # Transform to global coordinates and apply gains
        mux = function.mux + x0
        muy = function.muy + y0
        amp = function.amp / qe * gain
        offset = function.offset / qe * gain
        if cuda:
            mux = xp.asnumpy(mux)
            muy = xp.asnumpy(muy)
            muz = xp.asnumpy(muz)
            amp = xp.asnumpy(amp)
            offset = xp.asnumpy(offset)
            converged = xp.asnumpy(converged)
        mux_all.append(mux)
        muy_all.append(muy)
        muz_all.append(muz)
        amp_all.append(amp)
        offset_all.append(offset)
        converged_all.append(converged)

        # Compute sigma values per model (NaN where the model has no sigma)
        nan_sigma = xp.full_like(x0, fill_value=np.nan, dtype=xp.float32)
        if model == "isogauss":
            sig = function.sig
            sigx = function.sig
            sigy = function.sig
        elif model == "gauss":
            sigx = function.sigx
            sigy = function.sigy
            # Geometric mean of both axes as the equivalent isotropic sigma
            sig = xp.sqrt(sigx * sigy)
        else:
            sig = nan_sigma
            sigx = nan_sigma
            sigy = nan_sigma
        if cuda:
            sig = xp.asnumpy(sig)
            sigx = xp.asnumpy(sigx)
            sigy = xp.asnumpy(sigy)
        if has_isogauss:
            sigma_all.append(sig)
        if has_gauss:
            sigmax_all.append(sigx)
            sigmay_all.append(sigy)

    # Pack results, remapped to detection order through `positions`
    info = {
        "amp": stack_channel_values(amp_all, positions),
        "offset": stack_channel_values(offset_all, positions),
        "sigma": stack_channel_values(sigma_all, positions) if has_isogauss else None,
        "sigmax": stack_channel_values(sigmax_all, positions) if has_gauss else None,
        "sigmay": stack_channel_values(sigmay_all, positions) if has_gauss else None,
        "converged": stack_channel_values(converged_all, positions),
        "models": channels_fit_models,
    }
    return (
        stack_channel_values(mux_all, positions),
        stack_channel_values(muy_all, positions),
        stack_channel_values(muz_all, positions),
        info,
    )
def _normalize_channels_fit_models(values, n_channels):
"""Normalize and validate per-channel model names."""
if isinstance(values, str):
values = [values for _ in range(n_channels)]
elif len(values) != n_channels:
raise ValueError("channels_fit_models must have same length as crops")
models = [str(value).lower() for value in values]
invalid = [model for model in models if model not in FIT_MODELS]
if invalid:
raise ValueError(
f"channels_fit_models contains unsupported values {invalid}; "
f"expected one of {FIT_MODELS}"
)
return models
def _normalize_channels_pixels_nm(channels_pixels_nm, n_channels):
"""Normalize pixel sizes to one ``(py, px)`` tuple per channel."""
try:
if len(channels_pixels_nm) != n_channels:
if len(channels_pixels_nm) == 2:
channels_pixels_nm = [channels_pixels_nm for _ in range(n_channels)]
else:
raise ValueError("channels_pixels_nm must have same length as crops")
except TypeError:
channels_pixels_nm = [
(channels_pixels_nm, channels_pixels_nm)
for _ in range(n_channels)
]
return channels_pixels_nm
def _normalize_channels_parameter(values, n_channels):
"""Normalize scalar/per-channel values to a per-channel sequence."""
if isinstance(values, str):
return [values for _ in range(n_channels)]
try:
if len(values) != n_channels:
raise ValueError("parameter must have same length as crops")
except TypeError:
values = [values for _ in range(n_channels)]
return values
def _normalize_spline_parameter(values, n_channels, name, required):
"""Normalize spline metadata to one object per channel."""
if values is None:
if required:
raise SyntaxError(f"{name} must be specified for spline channels")
return [None for _ in range(n_channels)]
if isinstance(values, np.ndarray):
if n_channels != 1:
raise ValueError(f"{name} must have same length as crops")
return [values]
if len(values) != n_channels:
raise ValueError(f"{name} must have same length as crops")
return values
def _resolve_optimizer(optimizer):
"""Resolve optimizer key to optimizer class."""
match optimizer.lower():
case "lm":
return LM
case _:
raise SyntaxError(f"Optimizer {optimizer} is not recognized")
def _resolve_distribution(distribution):
"""Resolve distribution key to instantiated distribution."""
match distribution.lower():
case "normal":
return Normal()
case "poisson":
return Poisson()
case _:
raise SyntaxError(f"Distribution {distribution} is not recognized")
def _resolve_estimator(estimator, distribution):
"""Resolve estimator key to instantiated estimator."""
match estimator.lower():
case "mle":
return MLE(distribution)
case "lse":
return LSE()
case _:
raise SyntaxError(f"Estimator {estimator} is not recognized")
def _make_model(*, model, mux, muy, amp, offset, pixel, sigma, sigx, sigy, theta, fit_theta, tx, ty, tz, coeffs, cuda):
"""Instantiate the fit model selected for one channel."""
if model == "isogauss":
return IsoGaussian(mux=mux, muy=muy, amp=amp, offset=offset, sig=sigma, pixx=pixel[1], pixy=pixel[0], cuda=cuda)
if model == "gauss":
return Gaussian2D(mux=mux, muy=muy, amp=amp, offset=offset, sigx=sigx, sigy=sigy, theta=theta, theta_fit=fit_theta, pixx=pixel[1], pixy=pixel[0], cuda=cuda)
if model == "spline":
return Spline3D(mux=mux, muy=muy, muz=mux * 0, amp=amp, offset=offset, tx=tx, ty=ty, tz=tz, coeffs=coeffs, cuda=cuda)
raise ValueError(f"Unknown model: {model}")
# Script entry point: only executed when this file is run directly.
if __name__ == "__main__":
    # Imported here so that corelp is only needed when the file is run as a script.
    from corelp import test
    # NOTE(review): presumably runs the tests associated with this file — confirm against corelp.
    test(__file__)