Source code for smlmlp.modules.block_LP._functions.registration.registrate_transform
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author : Lancelot PINCET
# GitHub : https://github.com/LancelotPincet
from smlmlp import block
from arrlp import img_transform, transform_matrix
import numpy as np
[docs]
@block(timeit=False)
def registrate_transform(
channels, /,
channels_x_shifts_nm=None, channels_y_shifts_nm=None,
channels_rotations_deg=None, channels_x_shears=None, channels_y_shears=None,
*,
channels_pixels_nm=1.0,
cuda=False, parallel=False,
):
"""Apply per-channel affine transforms in place to channel stacks."""
# Normalize per-channel pixel sizes to (y, x) tuples.
channels_pixels_nm = _normalize_channels_pixels_nm(channels_pixels_nm, len(channels))
# Normalize per-channel transformation parameters.
channels_x_shifts_nm = _normalize_channels_parameter(channels_x_shifts_nm, len(channels), "channels_x_shifts_nm")
channels_y_shifts_nm = _normalize_channels_parameter(channels_y_shifts_nm, len(channels), "channels_y_shifts_nm")
channels_rotations_deg = _normalize_channels_parameter(channels_rotations_deg, len(channels), "channels_rotations_deg")
channels_x_shears = _normalize_channels_parameter(channels_x_shears, len(channels), "channels_x_shears")
channels_y_shears = _normalize_channels_parameter(channels_y_shears, len(channels), "channels_y_shears")
# Build one transform matrix per channel and apply each transform in place.
transformed = []
for i, channel in enumerate(channels):
matrix = transform_matrix(
channel,
shiftx=channels_x_shifts_nm[i] / channels_pixels_nm[i][1],
shifty=channels_y_shifts_nm[i] / channels_pixels_nm[i][0],
angle=channels_rotations_deg[i],
shearx=channels_x_shears[i],
sheary=channels_y_shears[i],
)
transformed_channel = img_transform(
channel,
matrix=matrix,
stacks=True,
out=channel,
cuda=cuda,
parallel=parallel,
)
transformed.append(transformed_channel)
return transformed
def _normalize_channels_pixels_nm(channels_pixels_nm, n_channels):
"""Normalize pixel sizes to one (y, x) tuple per channel."""
try:
if len(channels_pixels_nm) != n_channels:
if len(channels_pixels_nm) == 2:
channels_pixels_nm = [channels_pixels_nm for _ in range(n_channels)]
else:
raise ValueError("channels_pixels_nm does not have the same length as channels")
except TypeError:
channels_pixels_nm = [(channels_pixels_nm, channels_pixels_nm) for _ in range(n_channels)]
return channels_pixels_nm
def _normalize_channels_parameter(values, n_channels, name):
"""Normalize one per-channel transform parameter to float32 values."""
if values is None:
return np.zeros(n_channels, dtype=np.float32)
values = np.asarray(values, dtype=np.float32)
if values.ndim == 0:
return np.full(n_channels, values.item(), dtype=np.float32)
if len(values) != n_channels:
raise ValueError(f"{name} does not have the same length as channels")
return values