#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author : Lancelot PINCET
# GitHub : https://github.com/LancelotPincet
from smlmlp import block
from arrlp import transform_parameters
from scipy.spatial import cKDTree
import cv2
import numpy as np
@block()
def registrate_locs_affine(
    x,
    y,
    ch,
    fr,
    /,
    shape,
    ref_pix=1.0,
    match_radius_nm=80.0,
    min_pairs=30,
    max_points_per_channel=20000,
    ransac_thresh_nm=20.0,
    ransac_confidence=0.995,
    ransac_max_iters=5000,
    *,
    cuda=False,
    parallel=False,
):
    """Estimate redundant pairwise affine transforms from localizations.

    One affine transform is estimated for every channel pair ``(i, j)`` with
    ``i < j`` from matched localization coordinates. Nearest-neighbor matches
    are constrained to points from the same frame.

    Pairwise transform parameters are returned using the same convention as
    ``registrate_ecc_affine`` so outputs can be passed to
    ``registrate_solve_redundant_affine``.

    Parameters
    ----------
    x, y : array-like
        1-D localization coordinates, in the same length unit as ``ref_pix``
        and the ``*_nm`` tolerances (nanometers, per the parameter names).
    ch : array-like
        1-D channel id of each localization. Ids ``<= 0`` are ignored; the
        remaining ids must be exactly ``1..N``.
    fr : array-like
        1-D frame id of each localization. Only points sharing a frame are
        matched. Rows whose numeric frame id is not finite are dropped.
    shape : sequence of 2 ints
        Image shape ``(y, x)`` forwarded to the parameter decomposition.
    ref_pix : float or (float, float)
        Pixel size, either isotropic or given as ``(y, x)``.
    match_radius_nm : float
        Maximum mutual-nearest-neighbor distance accepted as a match.
    min_pairs : int
        Minimum number of points per channel, and of matches per pair,
        required before a fit is attempted.
    max_points_per_channel : int or None
        When a channel has more valid points than this, a random subset of
        this size (fixed seed) is used. ``None`` disables subsampling.
    ransac_thresh_nm, ransac_confidence, ransac_max_iters
        RANSAC settings forwarded to ``cv2.estimateAffine2D``. The threshold
        is converted to pixels first.
    cuda, parallel : bool
        Accepted for interface compatibility; currently unused.

    Returns
    -------
    shiftx, shifty, angle, shearx, sheary, scalex, scaley : list
        One entry per channel pair, in the order of ``info["pairs"]``.
        Shifts are multiplied by the pixel size. A pair that could not be
        fitted gets the parameters of the identity matrix and is listed in
        ``info["failed_pairs"]``.
    info : dict
        Diagnostics: pair indices (zero-based, i.e. channel id minus one),
        3x3 and raw OpenCV matrices, match/inlier/frame counts and settings.

    Raises
    ------
    ValueError
        On non-1-D or mismatched inputs, a ``shape`` without 2 values, a
        malformed ``ref_pix``, or missing / non-contiguous channel ids.
    """
    _ = (cuda, parallel)  # accepted for API compatibility; unused here
    x = np.asarray(x, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    ch = np.asarray(ch, dtype=np.int64)
    fr = np.asarray(fr)
    if x.ndim != 1 or y.ndim != 1 or ch.ndim != 1 or fr.ndim != 1:
        raise ValueError("x, y, ch, and fr must be one-dimensional")
    if len(x) != len(y) or len(x) != len(ch) or len(x) != len(fr):
        raise ValueError("x, y, ch, and fr must have the same length")
    if len(shape) != 2:
        raise ValueError("shape must contain 2 values (y, x)")
    ref_pix = _normalize_ref_pix(ref_pix)
    # NOTE: both tolerances are converted with the x pixel size only, so with
    # anisotropic pixels the matching radius is circular in pixel space, not
    # in physical units.
    radius_pix = float(match_radius_nm) / ref_pix[1]
    ransac_thresh_pix = float(ransac_thresh_nm) / ref_pix[1]

    channels = np.unique(ch[ch > 0])
    if len(channels) == 0:
        raise ValueError("No positive channel ids were provided")
    if not np.array_equal(channels, np.arange(1, len(channels) + 1)):
        raise ValueError("ch must contain contiguous one-based channel ids")

    channels_points = []
    channels_frames = []
    for channel in channels:
        points, frames = _channel_points(
            x, y, fr, ch == channel, ref_pix, max_points_per_channel
        )
        channels_points.append(points)
        channels_frames.append(frames)

    shiftx = []
    shifty = []
    angle = []
    shearx = []
    sheary = []
    scalex = []
    scaley = []
    pairs = []
    matrices = []
    warp_matrices = []
    n_matches = []
    n_inliers = []
    n_common_frames = []
    n_used_frames = []
    failed_pairs = []
    n_channels = len(channels_points)
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            pairs.append((i, j))
            (
                fit_matrix,
                warp_matrix,
                pair_nmatches,
                pair_ninliers,
                pair_ncommon_frames,
                pair_nused_frames,
            ) = _fit_pair_affine(
                channels_points[i],
                channels_points[j],
                channels_frames[i],
                channels_frames[j],
                radius_pix=radius_pix,
                ransac_thresh_pix=ransac_thresh_pix,
                min_pairs=min_pairs,
                ransac_confidence=ransac_confidence,
                ransac_max_iters=ransac_max_iters,
            )
            if fit_matrix is None:
                # Keep every output aligned with ``pairs``: fall back to the
                # identity transform and record the pair as failed.
                fit_matrix = np.eye(3, dtype=float)
                warp_matrix = np.array(
                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
                    dtype=np.float32,
                )
                failed_pairs.append((i, j))
            dx, dy, da, dsx, dsy, scx, scy = _matrix_to_parameters(
                fit_matrix,
                shape,
                ref_pix,
            )
            shiftx.append(dx)
            shifty.append(dy)
            angle.append(da)
            shearx.append(dsx)
            sheary.append(dsy)
            scalex.append(scx)
            scaley.append(scy)
            matrices.append(fit_matrix)
            warp_matrices.append(warp_matrix)
            n_matches.append(pair_nmatches)
            n_inliers.append(pair_ninliers)
            n_common_frames.append(pair_ncommon_frames)
            n_used_frames.append(pair_nused_frames)

    info = {
        "ref_pix": ref_pix,
        "shape": tuple(int(v) for v in shape),
        "pairs": pairs,
        "matrices": matrices,
        "warp_matrices": warp_matrices,
        "n_matches": n_matches,
        "n_inliers": n_inliers,
        "n_common_frames": n_common_frames,
        "n_used_frames": n_used_frames,
        "failed_pairs": failed_pairs,
        "frame_constrained": True,
        "match_radius_nm": float(match_radius_nm),
        "ransac_thresh_nm": float(ransac_thresh_nm),
    }
    return shiftx, shifty, angle, shearx, sheary, scalex, scaley, info


def _channel_points(x, y, fr, mask, ref_pix, max_points):
    """Return finite ``(x, y)`` pixel coordinates and frames of one channel.

    Rows selected by ``mask`` are converted to pixels (x by ``ref_pix[1]``,
    y by ``ref_pix[0]``). Rows with a non-finite coordinate, or a non-finite
    numeric frame id, are dropped. If more than ``max_points`` rows remain,
    a random subset of ``max_points`` rows is kept.
    """
    points = np.column_stack(
        (
            x[mask] / ref_pix[1],
            y[mask] / ref_pix[0],
        )
    ).astype(np.float32, copy=False)
    frames = np.asarray(fr[mask])
    finite_points = np.isfinite(points).all(axis=1)
    if np.issubdtype(frames.dtype, np.number):
        finite_frames = np.isfinite(frames)
    else:
        # Non-numeric frame labels cannot be NaN/inf; keep them all.
        finite_frames = np.ones(len(frames), dtype=bool)
    valid = finite_points & finite_frames
    points = points[valid]
    frames = frames[valid]
    if max_points is not None and len(points) > max_points:
        # A fresh generator with a fixed seed per channel makes the subset
        # reproducible across calls.
        rng = np.random.default_rng(0)
        keep = rng.choice(len(points), int(max_points), replace=False)
        points = points[keep]
        frames = frames[keep]
    return points, frames
def _fit_pair_affine(
    template_pts,
    image_pts,
    template_fr,
    image_fr,
    *,
    radius_pix,
    ransac_thresh_pix,
    min_pairs,
    ransac_confidence,
    ransac_max_iters,
):
    """Fit one affine transform from channel j points to channel i points.

    In every frame present in both channels, candidate correspondences are
    mutual nearest neighbors closer than ``radius_pix``. An affine model
    mapping ``image_pts`` onto ``template_pts`` is then fitted to all of
    them with ``cv2.estimateAffine2D`` using RANSAC.

    Parameters
    ----------
    template_pts, image_pts : ndarray
        ``(N, 2)`` pixel coordinates of the reference (channel i) and moving
        (channel j) localizations.
    template_fr, image_fr : array-like
        Frame id of every row of ``template_pts`` / ``image_pts``.
    radius_pix : float
        Maximum matching distance, in pixels.
    ransac_thresh_pix : float
        RANSAC reprojection threshold, in pixels.
    min_pairs : int
        Minimum number of points per channel, and of matches, needed to fit.
    ransac_confidence, ransac_max_iters
        Forwarded to ``cv2.estimateAffine2D``.

    Returns
    -------
    tuple
        ``(fit_matrix, warp_matrix, n_matches, n_inliers, n_common_frames,
        n_used_frames)``. ``fit_matrix`` is a 3x3 matrix in ndimage (y, x)
        order, and ``warp_matrix`` is the raw 2x3 OpenCV result. Both are
        ``None`` when no fit was made.
    """
    if len(template_pts) < min_pairs or len(image_pts) < min_pairs:
        return None, None, 0, 0, 0, 0
    template_fr = np.asarray(template_fr)
    image_fr = np.asarray(image_fr)
    common_frames = np.intersect1d(template_fr, image_fr)
    if len(common_frames) == 0:
        return None, None, 0, 0, 0, 0
    n_common = len(common_frames)

    # Group rows by frame once (sort + binary search). Scanning every point
    # for every common frame costs O(n_frames * n_points) per pair.
    template_groups = _frame_index_groups(template_fr, common_frames)
    image_groups = _frame_index_groups(image_fr, common_frames)

    src_match = []
    dst_match = []
    for template_idx, image_idx in zip(template_groups, image_groups):
        template_frame_pts = template_pts[template_idx]
        image_frame_pts = image_pts[image_idx]
        # Nearest template point of every image point, within the radius;
        # misses are reported with an infinite distance.
        tree_i = cKDTree(template_frame_pts)
        d_ij, nn_i = tree_i.query(
            image_frame_pts, k=1, distance_upper_bound=radius_pix
        )
        valid_j = np.isfinite(d_ij)
        if not np.any(valid_j):
            continue
        j_local = np.flatnonzero(valid_j)
        i_local = nn_i[valid_j].astype(np.int64)
        # Keep only mutual matches: the chosen template point must, in turn,
        # have this image point as its own nearest neighbor.
        tree_j = cKDTree(image_frame_pts)
        d_ji, nn_j = tree_j.query(
            template_frame_pts[i_local],
            k=1,
            distance_upper_bound=radius_pix,
        )
        mutual = np.isfinite(d_ji) & (nn_j.astype(np.int64) == j_local)
        if not np.any(mutual):
            continue
        src_match.append(image_frame_pts[j_local[mutual]])
        dst_match.append(template_frame_pts[i_local[mutual]])

    used_frames = len(src_match)
    if used_frames == 0:
        return None, None, 0, 0, n_common, 0
    src = np.ascontiguousarray(np.concatenate(src_match, axis=0), dtype=np.float32)
    dst = np.ascontiguousarray(np.concatenate(dst_match, axis=0), dtype=np.float32)
    if len(src) < min_pairs:
        return None, None, len(src), 0, n_common, used_frames
    warp_matrix, inlier_mask = cv2.estimateAffine2D(
        src,
        dst,
        method=cv2.RANSAC,
        ransacReprojThreshold=ransac_thresh_pix,
        maxIters=int(ransac_max_iters),
        confidence=float(ransac_confidence),
        refineIters=10,
    )
    if warp_matrix is None:
        return None, None, len(src), 0, n_common, used_frames
    ninliers = int(np.count_nonzero(inlier_mask)) if inlier_mask is not None else len(src)
    fit_matrix = _cv2_to_ndimage_matrix(warp_matrix)
    return fit_matrix, warp_matrix, len(src), ninliers, n_common, used_frames


def _frame_index_groups(frames, wanted):
    """Return, for each value of sorted ``wanted``, the ascending row indices of ``frames`` equal to it."""
    # A stable sort keeps rows of equal frame in their original (ascending)
    # order, matching what ``np.where(frames == value)`` would return.
    order = np.argsort(frames, kind="stable")
    sorted_frames = frames[order]
    starts = np.searchsorted(sorted_frames, wanted, side="left")
    stops = np.searchsorted(sorted_frames, wanted, side="right")
    return [order[start:stop] for start, stop in zip(starts, stops)]
def _normalize_ref_pix(ref_pix):
    """Normalize pixel size to a (y, x) pair of floats.

    Parameters
    ----------
    ref_pix : float or sequence of 2 floats
        Either one isotropic pixel size or an explicit ``(y, x)`` pair.

    Returns
    -------
    tuple of float
        ``(pixel_y, pixel_x)``.

    Raises
    ------
    ValueError
        If a sequence without exactly two values is given, or if a pixel
        size is not finite and strictly positive (it is used as a divisor).
    """
    try:
        n_values = len(ref_pix)
    except TypeError:
        # Scalars (including NumPy scalars) have no len(): isotropic pixel.
        ref_pix = (ref_pix, ref_pix)
    else:
        if n_values != 2:
            raise ValueError("ref_pix does not have 2 values (y, x)")
    pix_y = float(ref_pix[0])
    pix_x = float(ref_pix[1])
    # Chained comparisons reject zero, negatives, NaN and +inf at once.
    if not (0.0 < pix_y < np.inf and 0.0 < pix_x < np.inf):
        raise ValueError("ref_pix values must be finite and strictly positive")
    return pix_y, pix_x
def _cv2_to_ndimage_matrix(warp_matrix):
    """Re-express a 2x3 OpenCV affine warp as a 3x3 ndimage-style matrix.

    OpenCV orders coordinates as (x, y), while the ndimage convention is
    (y, x). Switching conventions therefore reverses both the rows and the
    columns of the 2x2 linear part and swaps the two translation terms. The
    result is completed with a ``[0, 0, 1]`` homogeneous row.
    """
    cv_warp = np.asarray(warp_matrix, dtype=float)
    ndimage_matrix = np.eye(3, dtype=float)
    # Rows and columns 1, 0: [[w11, w10], [w01, w00]].
    ndimage_matrix[:2, :2] = cv_warp[1::-1, 1::-1]
    # Translation becomes (t_y, t_x).
    ndimage_matrix[:2, 2] = cv_warp[1::-1, 2]
    return ndimage_matrix
def _matrix_to_parameters(matrix, shape, ref_pix):
    """Decompose an affine matrix into registration parameters.

    The rotation angle is first taken from the polar decomposition of the
    2x2 linear part and passed to ``transform_parameters``, which supplies
    the remaining terms.

    Returns
    -------
    tuple
        ``(shift_x, shift_y, angle, shear_x, shear_y, scale_x, scale_y)``.
        The shifts are multiplied by the x and y entries of ``ref_pix``.
        Presumably this converts them from pixels, assuming
        ``transform_parameters`` reports pixel shifts; confirm.
    """
    pix_y = ref_pix[0]
    pix_x = ref_pix[1]
    polar_angle = _polar_rotation_angle(matrix[:2, :2])
    (
        shift_x,
        shift_y,
        shear_x,
        shear_y,
        rot,
        scale_x,
        scale_y,
    ) = transform_parameters(matrix, shape, angle=polar_angle)
    return (
        shift_x * pix_x,
        shift_y * pix_y,
        rot,
        shear_x,
        shear_y,
        scale_x,
        scale_y,
    )
def _polar_rotation_angle(linear):
    """Return the sign-flipped angle, in degrees, of the rotation nearest to ``linear``.

    The nearest proper rotation ``R = U @ Vt`` is obtained from the SVD of
    the 2x2 linear part. When that product is a reflection (negative
    determinant), the last left singular vector is negated so that ``R`` is
    a proper rotation.
    """
    left, _singular, right_t = np.linalg.svd(linear)
    if np.linalg.det(left @ right_t) < 0:
        left[:, -1] = -left[:, -1]
    nearest = left @ right_t
    # NOTE(review): the negation presumably matches the angle convention
    # expected by ``transform_parameters``; confirm.
    return -np.degrees(np.arctan2(nearest[1, 0], nearest[0, 0]))