# Source code for smlmlp.modules.block_LP._functions.registration.registrate_locs_affine

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author        : Lancelot PINCET
# GitHub        : https://github.com/LancelotPincet



from smlmlp import block
from arrlp import transform_parameters
from scipy.spatial import cKDTree
import cv2
import numpy as np



@block()
def registrate_locs_affine(
    x,
    y,
    ch,
    fr,
    /,
    shape,
    ref_pix=1.0,
    match_radius_nm=80.0,
    min_pairs=30,
    max_points_per_channel=20000,
    ransac_thresh_nm=20.0,
    ransac_confidence=0.995,
    ransac_max_iters=5000,
    *,
    cuda=False,
    parallel=False,
):
    """Estimate redundant pairwise affine transforms from localizations.

    The function estimates one affine transform per channel pair from matched
    localization coordinates. Nearest-neighbor matches are constrained to
    points from the same frame. It returns pairwise transform parameters using
    the same convention as ``registrate_ecc_affine`` so outputs can be passed
    to ``registrate_solve_redundant_affine``.

    Parameters
    ----------
    x, y : array-like
        One-dimensional localization coordinates. They are divided by
        ``ref_pix`` (x by ``ref_pix[1]``, y by ``ref_pix[0]``) to obtain pixel
        coordinates, so they must use the same unit as ``ref_pix``
        (presumably nm, given the ``*_nm`` parameters).
    ch : array-like
        Channel id of each localization. Ids ``<= 0`` are ignored; the
        positive ids must be exactly ``1..n_channels``.
    fr : array-like
        Frame id of each localization. Matches are only searched between
        points sharing the same frame id. Non-finite numeric ids are dropped.
    shape : sequence of 2 values
        Image shape ``(y, x)``, forwarded to the parameter extraction.
    ref_pix : float or (float, float)
        Pixel size, scalar or ``(y, x)`` pair.
    match_radius_nm : float
        Maximum distance of a mutual nearest-neighbor match, converted to
        pixels with the x pixel size.
    min_pairs : int
        Minimum number of points per channel, and of pooled matches, required
        to attempt a fit.
    max_points_per_channel : int or None
        Per-channel cap on the number of points used for matching. Whole
        frames are kept, following one seeded pseudo-random frame order
        shared by all channels, so that matching partners are kept or dropped
        together. ``None`` disables the cap.
    ransac_thresh_nm, ransac_confidence, ransac_max_iters
        RANSAC settings for the affine fit (the threshold is converted to
        pixels with the x pixel size).
    cuda, parallel : bool
        Accepted for interface compatibility; unused by this implementation.

    Returns
    -------
    shiftx, shifty, angle, shearx, sheary, scalex, scaley : list
        One entry per zero-based channel pair ``(i, j)`` with ``i < j``, in
        the order of ``info["pairs"]``. A pair whose fit fails uses the
        identity matrix and is listed in ``info["failed_pairs"]``.
    info : dict
        Diagnostics: matrices, match/inlier/frame counts, failed pairs and
        the settings used.

    Raises
    ------
    ValueError
        On non-1D or mismatched inputs, a ``shape`` without two values, an
        invalid ``ref_pix``, or missing / non-contiguous channel ids.
    """
    _ = (cuda, parallel)  # accepted for interface compatibility, unused here

    x = np.asarray(x, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    ch = np.asarray(ch, dtype=np.int64)
    fr = np.asarray(fr)
    if x.ndim != 1 or y.ndim != 1 or ch.ndim != 1 or fr.ndim != 1:
        raise ValueError("x, y, ch, and fr must be one-dimensional")
    if len(x) != len(y) or len(x) != len(ch) or len(x) != len(fr):
        raise ValueError("x, y, ch, and fr must have the same length")
    if len(shape) != 2:
        raise ValueError("shape must contain 2 values (y, x)")

    ref_pix = _normalize_ref_pix(ref_pix)
    # Distances are expressed in x-pixel units (points are x / ref_pix[1]).
    radius_pix = float(match_radius_nm) / ref_pix[1]
    ransac_thresh_pix = float(ransac_thresh_nm) / ref_pix[1]

    channels = np.unique(ch[ch > 0])
    if len(channels) == 0:
        raise ValueError("No positive channel ids were provided")
    if not np.array_equal(channels, np.arange(1, len(channels) + 1)):
        raise ValueError("ch must contain contiguous one-based channel ids")

    channels_points = []
    channels_frames = []
    for channel in channels:
        points, frames = _valid_channel_points(x, y, fr, ch == channel, ref_pix)
        channels_points.append(points)
        channels_frames.append(frames)

    if max_points_per_channel is not None:
        max_points = int(max_points_per_channel)
        if any(len(points) > max_points for points in channels_points):
            # Subsample whole frames in one order shared by all channels so
            # that a localization and its partner in another channel are kept
            # or dropped together. Independent per-point sampling would keep a
            # true pair only with probability (kept fraction) ** 2.
            frame_values, frame_ranks = _shared_frame_order(channels_frames)
            for k in range(len(channels_points)):
                channels_points[k], channels_frames[k] = _subsample_by_frames(
                    channels_points[k],
                    channels_frames[k],
                    max_points,
                    frame_values,
                    frame_ranks,
                )

    shiftx = []
    shifty = []
    angle = []
    shearx = []
    sheary = []
    scalex = []
    scaley = []
    pairs = []
    matrices = []
    warp_matrices = []
    n_matches = []
    n_inliers = []
    n_common_frames = []
    n_used_frames = []
    failed_pairs = []

    n_channels = len(channels_points)
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            pairs.append((i, j))
            (
                fit_matrix,
                warp_matrix,
                pair_nmatches,
                pair_ninliers,
                pair_ncommon_frames,
                pair_nused_frames,
            ) = _fit_pair_affine(
                channels_points[i],
                channels_points[j],
                channels_frames[i],
                channels_frames[j],
                radius_pix=radius_pix,
                ransac_thresh_pix=ransac_thresh_pix,
                min_pairs=min_pairs,
                ransac_confidence=ransac_confidence,
                ransac_max_iters=ransac_max_iters,
            )
            if fit_matrix is None:
                # Fall back to the identity so every pair still yields
                # parameters; the failure is reported in info["failed_pairs"].
                fit_matrix = np.eye(3, dtype=float)
                warp_matrix = np.array(
                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
                    dtype=np.float32,
                )
                failed_pairs.append((i, j))

            dx, dy, da, dsx, dsy, scx, scy = _matrix_to_parameters(
                fit_matrix,
                shape,
                ref_pix,
            )
            shiftx.append(dx)
            shifty.append(dy)
            angle.append(da)
            shearx.append(dsx)
            sheary.append(dsy)
            scalex.append(scx)
            scaley.append(scy)
            matrices.append(fit_matrix)
            warp_matrices.append(warp_matrix)
            n_matches.append(pair_nmatches)
            n_inliers.append(pair_ninliers)
            n_common_frames.append(pair_ncommon_frames)
            n_used_frames.append(pair_nused_frames)

    info = {
        "ref_pix": ref_pix,
        "shape": tuple(int(v) for v in shape),
        "pairs": pairs,
        "matrices": matrices,
        "warp_matrices": warp_matrices,
        "n_matches": n_matches,
        "n_inliers": n_inliers,
        "n_common_frames": n_common_frames,
        "n_used_frames": n_used_frames,
        "failed_pairs": failed_pairs,
        "frame_constrained": True,
        "match_radius_nm": float(match_radius_nm),
        "ransac_thresh_nm": float(ransac_thresh_nm),
    }
    return shiftx, shifty, angle, shearx, sheary, scalex, scaley, info


def _valid_channel_points(x, y, fr, mask, ref_pix):
    """Return finite ``(x, y)`` pixel coordinates and frame ids of one channel.

    Points with a non-finite coordinate, or a non-finite numeric frame id,
    are dropped. Non-numeric frame ids are all considered valid.
    """
    points = np.column_stack(
        (
            x[mask] / ref_pix[1],
            y[mask] / ref_pix[0],
        )
    ).astype(np.float32, copy=False)
    frames = np.asarray(fr[mask])
    valid = np.isfinite(points).all(axis=1)
    if np.issubdtype(frames.dtype, np.number):
        valid &= np.isfinite(frames)
    return points[valid], frames[valid]


def _shared_frame_order(channels_frames):
    """Return the sorted unique frame ids of all channels and a seeded random
    rank (a permutation of ``0..n_frames-1``) for each of them."""
    frame_values = np.unique(np.concatenate(channels_frames))
    frame_ranks = np.random.default_rng(0).permutation(len(frame_values))
    return frame_values, frame_ranks


def _subsample_by_frames(points, frames, max_points, frame_values, frame_ranks):
    """Keep whole frames, in shared rank order, while at most ``max_points``
    points are kept.

    Because every channel follows the same frame order, the frames kept by
    the most constrained channel are also fully kept by the others. If even
    the first frame exceeds the budget, fall back to a seeded random subset
    of ``max_points`` points.
    """
    if len(points) <= max_points:
        return points, frames
    point_ranks = frame_ranks[np.searchsorted(frame_values, frames)]
    ranks, counts = np.unique(point_ranks, return_counts=True)
    n_frames = int(np.searchsorted(np.cumsum(counts), max_points, side="right"))
    if n_frames == 0:
        keep = np.random.default_rng(0).choice(len(points), max_points, replace=False)
    else:
        keep = point_ranks <= ranks[n_frames - 1]
    return points[keep], frames[keep]
def _fit_pair_affine( template_pts, image_pts, template_fr, image_fr, *, radius_pix, ransac_thresh_pix, min_pairs, ransac_confidence, ransac_max_iters, ): """Fit one affine transform from channel j points to channel i points.""" if len(template_pts) < min_pairs or len(image_pts) < min_pairs: return None, None, 0, 0, 0, 0 template_fr = np.asarray(template_fr) image_fr = np.asarray(image_fr) common_frames = np.intersect1d( np.unique(template_fr), np.unique(image_fr), assume_unique=False, ) if len(common_frames) == 0: return None, None, 0, 0, 0, 0 src_match = [] dst_match = [] used_frames = 0 for frame in common_frames: template_idx = np.where(template_fr == frame)[0] image_idx = np.where(image_fr == frame)[0] if len(template_idx) == 0 or len(image_idx) == 0: continue template_frame_pts = template_pts[template_idx] image_frame_pts = image_pts[image_idx] if len(template_frame_pts) == 0 or len(image_frame_pts) == 0: continue tree_i = cKDTree(template_frame_pts) tree_j = cKDTree(image_frame_pts) d_ij, nn_i = tree_i.query(image_frame_pts, k=1, distance_upper_bound=radius_pix) valid_j = np.isfinite(d_ij) if not np.any(valid_j): continue j_local = np.where(valid_j)[0] i_local = nn_i[valid_j].astype(np.int64) d_ji, nn_j = tree_j.query( template_frame_pts[i_local], k=1, distance_upper_bound=radius_pix, ) mutual = np.isfinite(d_ji) & (nn_j.astype(np.int64) == j_local) if not np.any(mutual): continue src_match.append(image_frame_pts[j_local[mutual]]) dst_match.append(template_frame_pts[i_local[mutual]]) used_frames += 1 if len(src_match) == 0: return None, None, 0, 0, len(common_frames), 0 src = np.ascontiguousarray(np.concatenate(src_match, axis=0), dtype=np.float32) dst = np.ascontiguousarray(np.concatenate(dst_match, axis=0), dtype=np.float32) if len(src) < min_pairs: return None, None, len(src), 0, len(common_frames), used_frames warp_matrix, inlier_mask = cv2.estimateAffine2D( src, dst, method=cv2.RANSAC, ransacReprojThreshold=ransac_thresh_pix, 
maxIters=int(ransac_max_iters), confidence=float(ransac_confidence), refineIters=10, ) if warp_matrix is None: return None, None, len(src), 0, len(common_frames), used_frames ninliers = int(np.count_nonzero(inlier_mask)) if inlier_mask is not None else len(src) fit_matrix = _cv2_to_ndimage_matrix(warp_matrix) return fit_matrix, warp_matrix, len(src), ninliers, len(common_frames), used_frames def _normalize_ref_pix(ref_pix): """Normalize pixel size to a (y, x) pair.""" try: if len(ref_pix) != 2: raise ValueError("ref_pix does not have 2 values (y, x)") except TypeError: ref_pix = (ref_pix, ref_pix) return float(ref_pix[0]), float(ref_pix[1]) def _cv2_to_ndimage_matrix(warp_matrix): """Convert an OpenCV x/y affine warp to the ndimage y/x convention.""" warp_matrix = np.asarray(warp_matrix, dtype=float) matrix = np.eye(3, dtype=float) matrix[0, 0] = warp_matrix[1, 1] matrix[0, 1] = warp_matrix[1, 0] matrix[0, 2] = warp_matrix[1, 2] matrix[1, 0] = warp_matrix[0, 1] matrix[1, 1] = warp_matrix[0, 0] matrix[1, 2] = warp_matrix[0, 2] return matrix def _matrix_to_parameters(matrix, shape, ref_pix): """Recover transform parameters from an affine matrix.""" rotation = _polar_rotation_angle(matrix[:2, :2]) dx, dy, dsx, dsy, da, scx, scy = transform_parameters( matrix, shape, angle=rotation, ) return dx * ref_pix[1], dy * ref_pix[0], da, dsx, dsy, scx, scy def _polar_rotation_angle(linear): """Return the closest proper-rotation angle for a 2x2 affine matrix.""" u, _, vt = np.linalg.svd(linear) rotation = u @ vt if np.linalg.det(rotation) < 0: u[:, -1] *= -1 rotation = u @ vt theta = np.arctan2(rotation[1, 0], rotation[0, 0]) return -np.degrees(theta)