Source code for smlmlp.modules.analysis_LP._functions.associate.associate_different_channels

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author        : Lancelot PINCET
# GitHub        : https://github.com/LancelotPincet


import os

from joblib import Parallel, delayed
import numpy as np
from scipy.spatial import cKDTree
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components
from smlmlp import analysis


@analysis(df_name="detections")
def associate_different_channels(
    x_globfit,
    y_globfit,
    fr,
    ch,
    association_radius_nm=30.0,
    *,
    cuda=False,
    parallel=False,
):
    """
    Associate same-frame localizations from different channels.

    Within each frame, localizations closer than ``association_radius_nm``
    and belonging to different channels are linked. Each connected component
    of that graph is handed to ``_solve_multichannel_component``, which
    returns groups holding one localization per required channel. Every
    group receives its own point index.

    Parameters
    ----------
    x_globfit, y_globfit : array-like
        Localization coordinates (converted to float32).
    fr : array-like
        Frame identifiers (converted to uint32).
    ch : array-like
        Channel identifiers (converted to uint8). Channel 0 is not part of
        the required channels.
    association_radius_nm : float, optional
        Maximum spatial association radius.
    cuda : bool or int, optional
        Execution option accepted by all analysis functions. It is not used
        by this function body.
    parallel : bool or int, optional
        ``False``, 0 or 1 runs serially; ``True`` or -1 uses all CPUs but
        one; an integer greater than 1 is the number of threads.

    Returns
    -------
    point : ndarray uint64
        Point index per localization, in input order. 0 means
        unassigned/problematic. Indices start at 1 and grow with the frame
        identifier.
    info : dict
        Diagnostics with keys ``"n_ambiguous_components"``, ``"n_groups"``
        and ``"max_component_size"``.

    Raises
    ------
    ValueError
        If ``parallel`` is not one of the accepted values.
    """

    all_but_one_cpu = max(1, (os.cpu_count() or 2) - 1)

    # Booleans must be handled before any integer comparison: ``True == 1``
    # in Python, so testing ``parallel == 1`` first would silently turn
    # ``parallel=True`` into a serial run.
    if isinstance(parallel, (bool, np.bool_)):
        n_jobs = all_but_one_cpu if parallel else 1
    elif parallel == -1:
        n_jobs = all_but_one_cpu
    elif parallel == 0 or parallel == 1:
        n_jobs = 1
    elif isinstance(parallel, (int, np.integer)) and parallel > 1:
        n_jobs = int(parallel)
    else:
        raise ValueError("Invalid value for 'parallel'")

    x = np.asarray(x_globfit, dtype=np.float32)
    y = np.asarray(y_globfit, dtype=np.float32)
    fr = np.asarray(fr, dtype=np.uint32)
    ch = np.asarray(ch, dtype=np.uint8)

    # Handle empty input before touching ``np.unique(ch)[0]``, which would
    # raise IndexError on an empty array.
    n = len(fr)
    if n == 0:
        info = {
            "n_ambiguous_components": 0,
            "n_groups": 0,
            "max_component_size": 0,
        }
        return np.empty(0, dtype=np.uint64), info

    # np.unique returns sorted values, so channel 0 (if present) is first.
    required_channels = np.unique(ch)
    if required_channels.size and required_channels[0] == 0:
        required_channels = required_channels[1:]
    n_required = len(required_channels)

    radius = float(association_radius_nm)
    radius2 = radius * radius

    # Stable sort by frame so each frame is one contiguous slice.
    order = np.argsort(fr, kind="stable")
    xs = x[order]
    ys = y[order]
    fs = fr[order]
    cs = ch[order]

    unique, start, counts = np.unique(
        fs,
        return_index=True,
        return_counts=True,
    )

    def _associate_one_frame(k):
        """Associate localizations within one frame."""

        a0 = start[k]
        a1 = a0 + counts[k]

        xf = xs[a0:a1]
        yf = ys[a0:a1]
        cf = cs[a0:a1]

        n_frame = len(xf)
        # Frame-local point indices; 0 means unassigned.
        point_local = np.zeros(n_frame, dtype=np.int64)

        stats = {
            "n_candidate_edges": 0,
            "n_components": 0,
            "n_missing_components": 0,
            "n_ambiguous_components": 0,
            "n_valid_components": 0,
            "n_groups": 0,
            "max_component_size": 0,
        }

        if n_frame == 0:
            return point_local, stats

        tree = cKDTree(np.column_stack((xf, yf)))
        pairs = tree.query_pairs(radius, output_type="ndarray")
        if len(pairs) == 0:
            return point_local, stats

        # Only pairs joining two different channels are association edges.
        pairs = pairs[cf[pairs[:, 0]] != cf[pairs[:, 1]]]
        if len(pairs) == 0:
            return point_local, stats

        stats["n_candidate_edges"] = len(pairs)

        # Symmetric adjacency matrix: each edge is stored in both directions.
        graph = coo_matrix(
            (
                np.ones(len(pairs) * 2, dtype=np.uint8),
                (
                    np.concatenate((pairs[:, 0], pairs[:, 1])),
                    np.concatenate((pairs[:, 1], pairs[:, 0])),
                ),
            ),
            shape=(n_frame, n_frame),
        )
        _, labels = connected_components(graph, directed=False)

        # Isolated nodes form their own components; keep only components
        # that contain at least one edge.
        used_nodes = np.unique(pairs.ravel())
        component_ids = np.unique(labels[used_nodes])
        stats["n_components"] = len(component_ids)

        current_point = 1
        for cid in component_ids:
            idx = np.where(labels == cid)[0]
            stats["max_component_size"] = max(
                stats["max_component_size"],
                len(idx),
            )

            # Too few localizations to cover every required channel.
            if len(idx) < n_required:
                stats["n_missing_components"] += 1
                continue

            groups, component_status = _solve_multichannel_component(
                xf[idx],
                yf[idx],
                cf[idx],
                required_channels,
                radius2,
            )

            # Status 0 is counted as a missing component, status 2 as an
            # ambiguous one.
            if component_status == 0:
                stats["n_missing_components"] += 1
                continue
            if component_status == 2:
                stats["n_ambiguous_components"] += 1
            if len(groups) == 0:
                continue

            stats["n_valid_components"] += 1
            stats["n_groups"] += len(groups)

            # ``groups`` holds component-local indices; map them back to
            # frame-local positions through ``idx``.
            for members in groups:
                point_local[idx[members]] = current_point
                current_point += 1

        return point_local, stats

    if n_jobs == 1:
        per_frame = [_associate_one_frame(k) for k in range(len(unique))]
    else:
        per_frame = Parallel(n_jobs=n_jobs, prefer="threads")(
            delayed(_associate_one_frame)(k) for k in range(len(unique))
        )

    # Turn frame-local indices into globally unique ones by shifting each
    # frame by the number of points already issued.
    point_sorted = np.zeros(n, dtype=np.int64)
    offset = 0
    stats_all = []
    for k, (point_local, stats) in enumerate(per_frame):
        stats_all.append(stats)
        first = start[k]
        last = first + counts[k]
        point_sorted[first:last] = np.where(
            point_local > 0,
            point_local + offset,
            0,
        )
        offset += int(point_local.max())

    # Undo the frame sort so the output follows the input order.
    point = np.zeros_like(point_sorted)
    point[order] = point_sorted

    info = {
        "n_ambiguous_components": int(
            sum(s["n_ambiguous_components"] for s in stats_all)
        ),
        "n_groups": int(sum(s["n_groups"] for s in stats_all)),
        "max_component_size": int(
            max(s["max_component_size"] for s in stats_all)
        ),
    }
    return point.astype(np.uint64), info


def _solve_multichannel_component(
    x,
    y,
    ch,
    required_channels,
    radius2,
):
    """
    Solve one multichannel connected component.

    Enumerates every group made of one localization per required channel
    whose members are all pairwise within the squared radius ``radius2``,
    then keeps a non-overlapping subset, lowest cost first. The cost of a
    group is the sum of squared distances of its members to their centroid.

    Returns
    -------
    groups : ndarray int64, shape (n_groups, n_required)
        Component-local indices, one column per required channel.
    status : int
        0 when a required channel is absent or no valid group exists,
        1 when every required channel has exactly one localization,
        2 when at least one required channel has several localizations.
    """

    n_required = len(required_channels)

    # Candidate localizations for each required channel, in channel order.
    per_channel = []
    ambiguous = False
    for channel in required_channels:
        members = np.where(ch == channel)[0]
        if len(members) == 0:
            return np.zeros((0, n_required), dtype=np.int64), 0
        ambiguous = ambiguous or len(members) > 1
        per_channel.append(members)

    def _fits(candidate, chosen):
        """Tell whether candidate is within the radius of all chosen members."""
        for previous in chosen:
            dx = x[candidate] - x[previous]
            dy = y[candidate] - y[previous]
            if dx * dx + dy * dy > radius2:
                return False
        return True

    def _expand(chosen):
        """Yield every complete valid group extending the partial one."""
        level = len(chosen)
        if level == n_required:
            yield chosen
            return
        for candidate in per_channel[level]:
            if _fits(candidate, chosen):
                yield from _expand(chosen + (candidate,))

    def _spread(members):
        """Return the summed squared distance of members to their centroid."""
        cx = 0.0
        cy = 0.0
        for member in members:
            cx += x[member]
            cy += y[member]
        cx /= n_required
        cy /= n_required

        total = 0.0
        for member in members:
            dx = x[member] - cx
            dy = y[member] - cy
            total += dx * dx + dy * dy
        return total

    found = list(_expand(()))
    if not found:
        return np.zeros((0, n_required), dtype=np.int64), 0

    groups = np.asarray(found, dtype=np.int64)
    costs = np.asarray([_spread(group) for group in found], dtype=np.float64)

    if not ambiguous:
        return groups[:1], 1

    # Greedy selection: take the cheapest groups first and skip any group
    # that reuses an already assigned localization.
    taken = np.zeros(len(x), dtype=bool)
    kept = []
    for rank in np.argsort(costs):
        members = groups[rank]
        if taken[members].any():
            continue
        kept.append(members.copy())
        taken[members] = True

    if not kept:
        return np.zeros((0, n_required), dtype=np.int64), 2
    return np.asarray(kept, dtype=np.int64), 2