# Source code for smlmlp.modules.block_LP._functions.registration.registrate_pcc_shift

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author        : Lancelot PINCET
# GitHub        : https://github.com/LancelotPincet



from smlmlp import block
from arrlp import img_crosscorr, get_xp
from funclp import Gaussian
from scipy.optimize import curve_fit
import numpy as np



@block()
def registrate_pcc_shift(
    optimized,
    /,
    ref_pix=1.0,
    fit_window=11,
    *,
    cuda=False,
    parallel=False,
):
    """
    Estimate redundant pairwise shifts from phase cross-correlation images.

    This function computes the phase cross-correlation between every pair of
    optimized channels and estimates a subpixel shift from the
    cross-correlation peak of each pair.

    Parameters
    ----------
    optimized : sequence of ndarray
        Sequence of optimized images, one per channel.
    ref_pix : float or sequence of 2 floats, optional
        Reference pixel size ``(y, x)`` used to convert shifts to physical
        units. If a scalar is provided, it is applied to both y and x as
        ``(ref_pix, ref_pix)``.
    fit_window : int, optional
        Half-size of the square crop used for Gaussian peak fitting around
        the integer PCC maximum. The fitted crop side is
        ``2*fit_window + 1``. Values below 1 are clamped to 1. Can be passed
        positionally or by keyword.
    cuda : bool, optional
        Whether to enable CUDA processing. The cross-correlation result is
        copied back to host memory before the peak fit.
    parallel : bool, optional
        Whether to enable parallel processing (forwarded to
        ``img_crosscorr``).

    Returns
    -------
    tuple
        A tuple ``(shiftx, shifty, info)`` where:

        - ``shiftx`` is the list of x shifts, one entry per channel pair,
        - ``shifty`` is the list of y shifts, one entry per channel pair,
        - ``info`` is a dictionary containing reusable intermediate results.

        The dictionary contains the following keys:

        ``'ref_pix'``
            Reference pixel size used for shift conversion (a scalar input
            is expanded to ``(ref_pix, ref_pix)``).
        ``'fit_window'``
            Crop half-size actually used for the peak fit.
        ``'pairs'``
            List of channel index pairs ``(i, j)`` corresponding to the
            cross-correlation and shift outputs.
        ``'CC'``
            List of phase cross-correlation results, one per channel pair,
            in the same order as ``'pairs'``.

    Raises
    ------
    ValueError
        If ``ref_pix`` is a sequence whose length is not 2.

    Notes
    -----
    For each channel pair, the peak of the phase cross-correlation is
    detected. A local Gaussian fit is then performed on x/y profiles of a
    square crop around that maximum to estimate a robust subpixel offset.

    Returned shifts follow ``registrate_transform`` convention, i.e. they are
    corrective shifts to apply to channel ``j`` in pair ``(i, j)``.

    Examples
    --------
    >>> import numpy as np
    >>> ch1 = np.random.rand(16, 16).astype(np.float32)
    >>> ch2 = np.random.rand(16, 16).astype(np.float32)
    >>> shiftx, shifty, info = registrate_pcc_shift([ch1, ch2])
    >>> len(info["CC"])
    1
    >>> info["pairs"]
    [(0, 1)]

    >>> ch3 = np.random.rand(16, 16).astype(np.float32)
    >>> shiftx, shifty, info = registrate_pcc_shift(
    ...     [ch1, ch2, ch3],
    ...     ref_pix=(100.0, 120.0),
    ... )
    >>> len(info["pairs"])
    3
    """
    from itertools import combinations

    # Select the array backend matching the requested execution mode.
    xp = get_xp(cuda)

    # Normalize the reference pixel size to a (y, x) pair.
    try:
        n_ref = len(ref_pix)
    except TypeError:
        # Scalar pixel size: use it for both axes.
        ref_pix = (ref_pix, ref_pix)
    else:
        if n_ref != 2:
            raise ValueError("ref_pix does not have 2 values (y, x)")

    fit_window = int(max(1, fit_window))

    # Compute the phase cross-correlation for each channel pair (i < j, in
    # lexicographic order) and estimate the corresponding shifts.
    CC = []
    shiftx = []
    shifty = []
    pairs = []
    for i, j in combinations(range(len(optimized)), 2):
        cc = img_crosscorr(
            optimized[i],
            optimized[j],
            phase=True,
            cuda=cuda,
            parallel=parallel,
        )
        if cuda:
            # The peak fit below runs on NumPy/SciPy, so move data to host.
            cc = xp.asnumpy(cc)

        dx, dy = subpixel_peak_stack(cc, ref_pix=ref_pix, fit_window=fit_window)

        CC.append(cc)
        shiftx.append(dx)
        shifty.append(dy)
        pairs.append((i, j))

    info = {
        "ref_pix": ref_pix,
        "fit_window": fit_window,
        "pairs": pairs,
        "CC": CC,
    }
    return shiftx, shifty, info
def subpixel_peak_stack(c, ref_pix=(1.0, 1.0), fit_window=11):
    """Estimate subpixel peak shifts from a cross-correlation image or stack.

    Parameters
    ----------
    c : array_like
        Cross-correlation image of shape ``(ny, nx)``, or a stack whose last
        two axes are the image axes (for example ``(nframes, ny, nx)``).
        The zero-shift position is taken to be ``(ny // 2, nx // 2)``.
    ref_pix : sequence of 2 floats, optional
        Pixel size ``(y, x)`` used to scale the shifts.
    fit_window : int, optional
        Half-size of the crop used for the Gaussian peak fit.

    Returns
    -------
    dx, dy
        Corrective shifts following the ``registrate_transform`` convention
        (negated peak offsets from the zero-shift position, times
        ``ref_pix``). Scalars for a 2D input; arrays of shape
        ``c.shape[:-2]`` (one value per frame) for a stack.

    Raises
    ------
    ValueError
        If ``c`` has fewer than 2 dimensions.
    """
    c = np.asarray(c)
    if c.ndim < 2:
        raise ValueError("c must have at least 2 dimensions (..., y, x)")
    if c.ndim == 2:
        return _subpixel_peak_frame(c, ref_pix, fit_window)

    # Treat every leading axis as a frame axis and fit each frame separately.
    frames = c.reshape((-1,) + c.shape[-2:])
    dx = np.empty(frames.shape[0], dtype=np.float64)
    dy = np.empty(frames.shape[0], dtype=np.float64)
    for k, frame in enumerate(frames):
        dx[k], dy[k] = _subpixel_peak_frame(frame, ref_pix, fit_window)
    lead_shape = c.shape[:-2]
    return dx.reshape(lead_shape), dy.reshape(lead_shape)


def _subpixel_peak_frame(c, ref_pix, fit_window):
    """Return corrective ``(dx, dy)`` for a single 2D cross-correlation image."""
    ny, nx = c.shape
    # np.argmax reports the first NaN as the maximum; rank non-finite
    # values lowest so they cannot be chosen as the integer peak.
    ranked = np.where(np.isfinite(c), c, -np.inf)
    iy, ix = np.unravel_index(int(np.argmax(ranked)), c.shape)
    dy_sub, dx_sub = _subpixel_peak_gaussian(c, iy, ix, fit_window=fit_window)

    # Return corrective shifts matching registrate_transform convention.
    dy = -((iy - ny // 2) + dy_sub) * ref_pix[0]
    dx = -((ix - nx // 2) + dx_sub) * ref_pix[1]
    return dx, dy


def _subpixel_peak_gaussian(c, iy, ix, fit_window=11):
    """Estimate the subpixel offset ``(dy, dx)`` of a PCC peak near ``(iy, ix)``.

    A square crop of half-size ``fit_window`` (clipped to the image bounds)
    is taken around the integer maximum; the mean x and y profiles of the
    crop are each fitted with a 1D Gaussian. Offsets are relative to
    ``(iy, ix)``. Returns ``(0.0, 0.0)`` when the crop is smaller than 3x3
    or holds no finite value.
    """
    half = int(max(1, fit_window))
    ny, nx = c.shape
    y0 = max(iy - half, 0)
    y1 = min(iy + half + 1, ny)
    x0 = max(ix - half, 0)
    x1 = min(ix + half + 1, nx)
    win = np.asarray(c[y0:y1, x0:x1], dtype=np.float64)
    if win.shape[0] < 3 or win.shape[1] < 3:
        return 0.0, 0.0

    mask = np.isfinite(win)
    if not np.any(mask):
        return 0.0, 0.0

    # Exclude the row and column through the integer maximum during fitting
    # to reduce FFT line artifacts, as done in blink_spatial_psf.
    # NOTE(review): these coincide with the zero-shift axes only when the
    # peak sits at the image centre — confirm this is the intended exclusion.
    cy = int(np.clip(iy - y0, 0, win.shape[0] - 1))
    cx = int(np.clip(ix - x0, 0, win.shape[1] - 1))
    fit_mask = mask.copy()
    fit_mask[cy, :] = False
    fit_mask[:, cx] = False
    if np.count_nonzero(fit_mask) < 3:
        # Too little data left after exclusion: fall back to all finite pixels.
        fit_mask = mask
    fit_win = np.where(fit_mask, win, np.nan)

    # Profile coordinates relative to the integer peak.
    x = np.arange(x0, x1, dtype=np.float64) - float(ix)
    y = np.arange(y0, y1, dtype=np.float64) - float(iy)
    profile_x = _nanmean_no_warn(fit_win, axis=0)
    profile_y = _nanmean_no_warn(fit_win, axis=1)
    mux = _fit_profile_gaussian(x, profile_x, fit_window=half)
    muy = _fit_profile_gaussian(y, profile_y, fit_window=half)
    return muy, mux


def _fit_profile_gaussian(coords, values, fit_window):
    """Fit a 1D Gaussian (with offset) to a profile and return its center.

    Non-finite samples are ignored. Returns 0.0 (no subpixel correction) when
    fewer than 3 valid samples remain or the fit fails; otherwise returns the
    fitted center clipped to the sampled range and to
    ``[-fit_window, fit_window]``.
    """
    values = np.asarray(values, dtype=np.float64)
    coords = np.asarray(coords, dtype=np.float64)
    valid = np.isfinite(values) & np.isfinite(coords)
    if np.count_nonzero(valid) < 3:
        return 0.0

    x = coords[valid]
    y = values[valid]
    xmin = float(x.min())
    xmax = float(x.max())
    ymin = float(np.min(y))
    ymax = float(np.max(y))
    span = xmax - xmin
    amp0 = max(ymax - ymin, 1e-6)
    sig0 = max(span / 4.0, 0.75)

    gaus = Gaussian(pix=1.0, sig=sig0)

    def func2fit(xi, mu, sig, amp, offset):
        return gaus(xi, mu=mu, sig=sig, amp=amp, offset=offset)

    lower = [xmin, 0.5, 0.0, ymin - amp0]
    upper = [xmax, float(max(3.0, span)), 2.0 * amp0, ymax + amp0]
    # Start at the integer maximum (coordinate 0), but keep the guess inside
    # the mu bounds: when the peak row/column was excluded at an image edge,
    # 0 lies outside [xmin, xmax] and curve_fit would reject p0 as infeasible.
    mu0 = min(max(0.0, xmin), xmax)
    p0 = [mu0, sig0, amp0, ymin]
    try:
        popt, _ = curve_fit(
            func2fit,
            x,
            y,
            p0=p0,
            bounds=(lower, upper),
            maxfev=4000,
        )
    except Exception:
        # Deliberate best-effort: any fit failure means "no subpixel shift".
        return 0.0

    mu = float(popt[0])
    if not np.isfinite(mu):
        return 0.0

    lo = max(xmin, -float(fit_window))
    hi = min(xmax, float(fit_window))
    return float(np.clip(mu, lo, hi))


def _nanmean_no_warn(array, axis):
    """Mean of the finite entries along ``axis``; NaN where none are finite.

    Non-finite values (NaN and +/-inf) are treated as missing, and slices
    without any finite value yield NaN without emitting a RuntimeWarning.
    """
    array = np.asarray(array, dtype=np.float64)
    valid = np.isfinite(array)
    counts = np.count_nonzero(valid, axis=axis)
    # Sum only the entries that are counted, so +/-inf cannot leak in.
    totals = np.sum(np.where(valid, array, 0.0), axis=axis)
    out = np.full(np.shape(totals), np.nan, dtype=np.float64)
    np.divide(totals, counts, out=out, where=counts > 0)
    return out