#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author : Lancelot PINCET
# GitHub : https://github.com/LancelotPincet
from smlmlp import block
from arrlp import img_crosscorr, get_xp
from funclp import Gaussian
from scipy.optimize import curve_fit
import numpy as np
@block()
def registrate_pcc_shift(
    optimized,
    /,
    ref_pix=1.0,
    fit_window=11,
    *,
    cuda=False,
    parallel=False,
):
    """
    Estimate redundant pairwise shifts from phase cross-correlation images.

    This function computes the phase cross-correlation (PCC) between every
    pair of optimized channels and estimates a subpixel shift from each
    cross-correlation peak.

    Parameters
    ----------
    optimized : sequence of ndarray
        Sequence of optimized images, one per channel.
    ref_pix : float or sequence of two floats, optional
        Reference pixel size ``(y, x)`` used to convert shifts to physical
        units. If a scalar is provided, it is applied to both y and x as
        ``(ref_pix, ref_pix)``.
    fit_window : int, optional
        Half-size of the square crop used for Gaussian peak fitting around
        the integer PCC maximum. The fitted crop side is
        ``2*fit_window + 1``. Values below 1 are raised to 1. Can be passed
        positionally or by keyword.
    cuda : bool, optional
        Whether to enable CUDA processing for the cross-correlation. The
        result is converted back with ``xp.asnumpy`` before peak fitting.
    parallel : bool, optional
        Whether to enable parallel processing (forwarded to
        ``img_crosscorr``).

    Returns
    -------
    tuple
        A tuple ``(shiftx, shifty, info)`` where:

        - ``shiftx`` is the list of x shifts, one entry per channel pair,
        - ``shifty`` is the list of y shifts, one entry per channel pair,
        - ``info`` is a dictionary containing reusable intermediate results.

        The dictionary contains the following keys:

        ``'ref_pix'``
            Reference pixel size ``(y, x)`` used for shift conversion.
        ``'fit_window'``
            Effective half-size of the Gaussian fitting crop.
        ``'pairs'``
            List of channel index pairs ``(i, j)`` corresponding to the
            cross-correlation and shift outputs.
        ``'CC'``
            List of phase cross-correlation arrays, one per pair, in the
            same order as ``'pairs'``.

    Raises
    ------
    ValueError
        If ``ref_pix`` is a sized object that does not hold exactly two
        values.

    Notes
    -----
    For each channel pair, the integer peak of the phase cross-correlation
    is detected. A local Gaussian fit is then performed on x/y profiles of a
    square crop around that maximum to estimate a subpixel offset.
    Returned shifts follow the ``registrate_transform`` convention, i.e. they
    are corrective shifts to apply to channel ``j`` in pair ``(i, j)``.

    Examples
    --------
    >>> import numpy as np
    >>> ch1 = np.random.rand(16, 16).astype(np.float32)
    >>> ch2 = np.random.rand(16, 16).astype(np.float32)
    >>> shiftx, shifty, info = registrate_pcc_shift([ch1, ch2])
    >>> len(info["CC"])
    1
    >>> info["pairs"]
    [(0, 1)]
    >>> ch3 = np.random.rand(16, 16).astype(np.float32)
    >>> shiftx, shifty, info = registrate_pcc_shift(
    ...     [ch1, ch2, ch3],
    ...     ref_pix=(100.0, 120.0),
    ... )
    >>> len(info["pairs"])
    3
    """
    # Select the array backend matching the requested execution mode.
    xp = get_xp(cuda)

    # Normalize the reference pixel size to a (y, x) pair. Only len() can
    # raise TypeError here (scalar input), so keep the try body to that call.
    try:
        n_ref = len(ref_pix)
    except TypeError:
        ref_pix = (ref_pix, ref_pix)
    else:
        if n_ref != 2:
            raise ValueError("ref_pix does not have 2 values (y, x)")
    fit_window = int(max(1, fit_window))

    # Compute the phase cross-correlation for each unordered channel pair
    # (i < j) and estimate the corresponding shifts.
    CC = []
    shiftx = []
    shifty = []
    pairs = []
    n_channels = len(optimized)
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            cc = img_crosscorr(
                optimized[i],
                optimized[j],
                phase=True,
                cuda=cuda,
                parallel=parallel,
            )
            if cuda:
                # Peak fitting below runs on NumPy/SciPy, so bring the
                # correlation back from the CUDA backend first.
                cc = xp.asnumpy(cc)
            dx, dy = subpixel_peak_stack(cc, ref_pix=ref_pix, fit_window=fit_window)
            CC.append(cc)
            shiftx.append(dx)
            shifty.append(dy)
            pairs.append((i, j))

    info = {
        "ref_pix": ref_pix,
        "fit_window": fit_window,
        "pairs": pairs,
        "CC": CC,
    }
    return shiftx, shifty, info
def subpixel_peak_stack(c, ref_pix=(1.0, 1.0), fit_window=11):
    """Estimate subpixel peak shifts from a cross-correlation image or stack.

    Parameters
    ----------
    c : array_like
        Cross-correlation image of shape ``(ny, nx)``, or a stack whose last
        two axes are ``(ny, nx)`` (e.g. ``(n_frames, ny, nx)``). Each trailing
        2D frame is processed independently.
    ref_pix : sequence of two floats, optional
        Pixel size ``(y, x)`` used to scale the shifts.
    fit_window : int, optional
        Half-size of the crop used for the Gaussian peak refinement.

    Returns
    -------
    dx, dy
        Corrective shifts (negated peak offsets from pixel
        ``(ny // 2, nx // 2)``) scaled by ``ref_pix``, following the
        ``registrate_transform`` convention. Scalars for a 2D input; float64
        arrays of shape ``c.shape[:-2]`` (one value per frame) for a stack.

    Raises
    ------
    ValueError
        If ``c`` has fewer than two dimensions.
    """
    c = np.asarray(c)
    if c.ndim < 2:
        raise ValueError("c must have at least 2 dimensions (..., ny, nx)")

    if c.ndim > 2:
        # Stack input: estimate one peak per trailing 2D frame.
        lead_shape = c.shape[:-2]
        dx = np.empty(lead_shape, dtype=np.float64)
        dy = np.empty(lead_shape, dtype=np.float64)
        for idx in np.ndindex(*lead_shape):
            dx[idx], dy[idx] = subpixel_peak_stack(
                c[idx], ref_pix=ref_pix, fit_window=fit_window
            )
        return dx, dy

    ny, nx = c.shape
    # np.argmax reports the first NaN as the maximum when one is present;
    # search only finite values so a NaN cannot be mistaken for the peak.
    finite = np.isfinite(c)
    search = c if finite.all() else np.where(finite, c, -np.inf)
    iy, ix = np.unravel_index(int(np.argmax(search)), c.shape)
    dy_sub, dx_sub = _subpixel_peak_gaussian(c, iy, ix, fit_window=fit_window)

    # Offsets are measured from pixel (ny // 2, nx // 2); presumably the
    # correlation is centered (fftshifted) so zero shift sits there — confirm.
    # Return corrective shifts matching registrate_transform convention.
    dy = -((iy - ny // 2) + dy_sub) * ref_pix[0]
    dx = -((ix - nx // 2) + dx_sub) * ref_pix[1]
    return dx, dy
def _subpixel_peak_gaussian(c, iy, ix, fit_window=11):
    """Refine an integer PCC maximum to subpixel precision.

    The function extracts a square crop of half-size ``fit_window`` (at
    least 1), centered on ``(iy, ix)`` and clipped to the image bounds. Its
    column means and row means form x and y profiles. Each profile is fitted
    independently with a 1D Gaussian by ``_fit_profile_gaussian``; this is a
    separable fit, not a full 2D Gaussian fit.

    Parameters
    ----------
    c : ndarray
        2D cross-correlation image.
    iy, ix : int
        Row and column of the integer maximum in ``c``.
    fit_window : int, optional
        Half-size of the crop; values below 1 are raised to 1.

    Returns
    -------
    tuple
        ``(dy, dx)`` subpixel offsets of the peak relative to ``(iy, ix)``,
        in pixels. ``(0.0, 0.0)`` when the crop is narrower than 3 pixels on
        either axis or contains no finite value.
    """
    radius = int(max(1, fit_window))
    height, width = c.shape
    row_lo, row_hi = max(iy - radius, 0), min(iy + radius + 1, height)
    col_lo, col_hi = max(ix - radius, 0), min(ix + radius + 1, width)
    crop = np.asarray(c[row_lo:row_hi, col_lo:col_hi], dtype=np.float64)

    # Guard clauses: nothing meaningful to fit.
    if min(crop.shape) < 3:
        return 0.0, 0.0
    finite = np.isfinite(crop)
    if not finite.any():
        return 0.0, 0.0

    # Drop the row and column through the integer maximum from the profiles
    # to reduce FFT line artifacts, as done in blink_spatial_psf.
    # NOTE(review): such artifacts usually sit on the zero-shift axes, which
    # coincide with the peak's row/column only for small shifts — confirm.
    peak_r = int(np.clip(iy - row_lo, 0, crop.shape[0] - 1))
    peak_c = int(np.clip(ix - col_lo, 0, crop.shape[1] - 1))
    usable = finite.copy()
    usable[peak_r, :] = False
    usable[:, peak_c] = False
    # Too few samples left after the exclusion: fall back to all finite ones.
    if np.count_nonzero(usable) < 3:
        usable = finite
    masked = np.where(usable, crop, np.nan)

    # Coordinates are offsets relative to the integer peak.
    col_offsets = np.arange(col_lo, col_hi, dtype=np.float64) - float(ix)
    row_offsets = np.arange(row_lo, row_hi, dtype=np.float64) - float(iy)

    mu_col = _fit_profile_gaussian(
        col_offsets, _nanmean_no_warn(masked, axis=0), fit_window=radius
    )
    mu_row = _fit_profile_gaussian(
        row_offsets, _nanmean_no_warn(masked, axis=1), fit_window=radius
    )
    return mu_row, mu_col
def _fit_profile_gaussian(coords, values, fit_window):
    """Fit a 1D Gaussian profile and return the fitted center.

    Parameters
    ----------
    coords : array_like
        Sample coordinates (offsets relative to the integer peak).
    values : array_like
        Profile values at ``coords``. Samples where either the value or the
        coordinate is non-finite are ignored.
    fit_window : float
        Bound on the absolute center offset. The result is clipped to
        ``[-fit_window, fit_window]`` intersected with the sampled
        coordinate range.

    Returns
    -------
    float
        Fitted Gaussian center, or ``0.0`` when fewer than three finite
        samples are available or the fit fails.
    """
    values = np.asarray(values, dtype=np.float64)
    coords = np.asarray(coords, dtype=np.float64)
    mask = np.isfinite(values) & np.isfinite(coords)
    if np.count_nonzero(mask) < 3:
        return 0.0
    x = coords[mask]
    y = values[mask]

    xmin = float(x.min())
    xmax = float(x.max())
    ymin = float(np.min(y))
    ymax = float(np.max(y))
    amp0 = max(ymax - ymin, 1e-6)
    sig0 = max((xmax - xmin) / 4.0, 0.75)
    gaus = Gaussian(pix=1.0, sig=sig0)

    def func2fit(xi, mu, sig, amp, offset):
        return gaus(xi, mu=mu, sig=sig, amp=amp, offset=offset)

    # Parameter order: (mu, sig, amp, offset).
    lower = [xmin, 0.5, 0.0, ymin - amp0]
    upper = [xmax, float(max(3.0, xmax - xmin)), 2.0 * amp0, ymax + amp0]
    # Start at the integer peak (offset 0) when it lies inside the sampled
    # range. It may not: if the samples at offset 0 were excluded and the
    # peak sits at a border, every coordinate is on one side of 0. An
    # initial guess outside `bounds` makes curve_fit raise, which would
    # silently discard the fit.
    mu0 = min(max(0.0, xmin), xmax)
    p0 = [mu0, sig0, amp0, ymin]
    try:
        popt, _ = curve_fit(
            func2fit,
            x,
            y,
            p0=p0,
            bounds=(lower, upper),
            maxfev=4000,
        )
        mu = float(popt[0])
    except Exception:
        # Deliberate best-effort: any fit failure (non-convergence, invalid
        # data) falls back to the integer peak position.
        return 0.0
    if not np.isfinite(mu):
        return 0.0
    lo = max(xmin, -float(fit_window))
    hi = min(xmax, float(fit_window))
    return float(np.clip(mu, lo, hi))
def _nanmean_no_warn(array, axis):
    """Average the finite entries of ``array`` along ``axis``, warning-free.

    Non-finite entries (NaN and +/-inf) are treated as missing. Slices that
    contain no finite entry yield NaN. Unlike ``np.nanmean``, no
    RuntimeWarning is emitted for such empty slices.

    Parameters
    ----------
    array : array_like
        Input values.
    axis : int
        Axis to reduce.

    Returns
    -------
    ndarray
        float64 means with ``axis`` removed.
    """
    array = np.asarray(array, dtype=np.float64)
    finite = np.isfinite(array)
    counts = np.count_nonzero(finite, axis=axis)
    # Sum only the entries that are counted, so numerator and denominator
    # agree (np.nansum would still add +/-inf while the count excludes it).
    totals = np.sum(np.where(finite, array, 0.0), axis=axis)
    out = np.full(np.shape(totals), np.nan, dtype=np.float64)
    # Divide only where at least one finite value exists; other slots keep
    # NaN and no division-by-zero warning is raised.
    np.divide(totals, counts, out=out, where=counts > 0)
    return out