Source code for wf_psf.data.safe_batch

"""Safe batch processing utilities.

This module provides utilities for filtering batches of aligned arrays
based on sample-wise validity criteria, typically derived from an
anchor array (e.g. centroid coordinates).

The core functionality ensures that all dataset components (images,
masks, metadata, etc.) remain aligned when invalid samples (NaNs, Infs)
are removed. This is critical for preventing silent misalignment bugs
in downstream processing.

It also provides lightweight logging helpers to track which samples
were filtered, improving traceability and debugging.

These utilities are intended for use during data preparation stages,
particularly after feature extraction steps such as centroid estimation.

Author(s): Jennifer Pollack <jennifer.pollack@cea.fr>
"""

from collections.abc import Sequence
import numbers
import numpy as np
from typing import Any


def _compute_valid_mask(anchor: np.ndarray) -> np.ndarray:
    """Return a boolean mask indicating which samples contain only finite values.

    A sample is considered valid if all its values are finite
    (i.e., not NaN or ±Inf), as determined by ``np.isfinite``.

    Parameters
    ----------
    anchor : np.ndarray
        Input array of shape (N,) or (N, D), where N is the number of samples.

    Returns
    -------
    np.ndarray
        Boolean mask of shape (N,). For 1D input, each element indicates whether
        the corresponding value is finite. For 2D input, each element indicates
        whether all values in the corresponding row are finite.
    """
    if anchor.ndim == 1:
        return np.isfinite(anchor)
    return np.isfinite(anchor).all(axis=1)


def safe_batch_builder(
    anchor: np.ndarray,
    **arrays: Any,
) -> tuple[np.ndarray, dict[str, Any]]:
    """Filter aligned arrays using a validity mask derived from an anchor array.

    This strict version enforces that all array-like and sequence inputs
    are aligned with the anchor along their first dimension. Any mismatch
    raises an error instead of being silently tolerated.

    Parameters
    ----------
    anchor : np.ndarray
        Array used to compute validity (typically centroids), of shape
        (N,) or (N, D). Samples containing NaN or ±Inf are dropped.
    **arrays : Any
        Per-sample data keyed by name. Each value is handled by kind:

        - ``None`` is passed through as ``None``.
        - Numbers (including NumPy scalars) and 0-d arrays are passed
          through unchanged, since they carry no per-sample axis.
        - Objects exposing ``__array__`` are converted with
          ``np.asarray``, must have length N, and are boolean-indexed by
          the mask.
        - Non-string sequences (lists, tuples, ...) must have length N
          and are returned as a filtered ``list``.

    Returns
    -------
    mask : np.ndarray
        Boolean mask of shape (N,) indicating valid samples.
    filtered : dict
        Dictionary with the same keys as ``arrays`` holding the filtered
        values.

    Raises
    ------
    ValueError
        If the validity mask is not 1D, the anchor is empty, every sample
        is invalid, or an array/sequence has a length different from N.
    TypeError
        If an array has ``dtype=object`` or a value is of an unsupported
        type (e.g. ``str``, ``bytes``, ``dict``).
    """
    mask = _compute_valid_mask(anchor)

    if mask.ndim != 1:
        raise ValueError("Validity mask must be 1D.")

    n = len(mask)

    if n == 0:
        raise ValueError("Empty anchor array.")

    if not mask.any():
        raise ValueError("All samples were filtered out.")

    filtered: dict[str, Any] = {}

    for key, arr in arrays.items():
        if arr is None:
            filtered[key] = None
            continue

        # --- Scalars ---
        # Checked before the array branch: NumPy scalars are numbers but
        # also expose ``__array__`` and would otherwise fail on ``len()``.
        if isinstance(arr, numbers.Number):
            filtered[key] = arr
            continue

        # --- Array-like objects ---
        if hasattr(arr, "__array__"):
            values = np.asarray(arr)

            # 0-d arrays have no sample axis; treat them like scalars.
            if values.ndim == 0:
                filtered[key] = arr
                continue

            if len(values) != n:
                raise ValueError(
                    f"Array '{key}' has length {len(values)} but expected {n}."
                )

            if values.dtype == object:
                raise TypeError(
                    f"Array '{key}' has dtype=object, which is not supported."
                )

            filtered[key] = values[mask]
            continue

        # --- Sequences (lists, tuples, etc.) ---
        if isinstance(arr, Sequence) and not isinstance(arr, (str, bytes)):
            if len(arr) != n:
                raise ValueError(
                    f"Sequence '{key}' has length {len(arr)} but expected {n}."
                )
            # Aligned with the anchor: keep only the valid entries.
            filtered[key] = [item for item, keep in zip(arr, mask) if keep]
            continue

        raise TypeError(f"Unsupported type for '{key}': {type(arr)}.")

    return mask, filtered
def log_filtered_objects(mask, obj_ids, logger, context=""):
    """Log identifiers of samples removed by a validity mask.

    Nothing is logged when the mask keeps every sample. Otherwise a
    warning reports how many samples were removed and a debug message
    lists their identifiers.

    Parameters
    ----------
    mask : array-like of bool
        Mask of shape (N,) indicating valid samples (``True`` = kept).
    obj_ids : array-like
        Identifiers aligned with the samples (length N). Lists and tuples
        are accepted as well as NumPy arrays.
    logger : logging.Logger
        Logger instance used for reporting.
    context : str, optional
        Additional context string appended to the warning in parentheses.
        Omitted from the message when empty.
    """
    # Coerce to a boolean array so that ``~`` is a logical NOT even when
    # the caller passes a list or an integer 0/1 mask.
    removed = ~np.asarray(mask, dtype=bool)
    n_removed = int(removed.sum())

    if n_removed == 0:
        return

    # Avoid a dangling "()" when no context was supplied.
    suffix = f" ({context})" if context else ""
    logger.warning("%d samples removed%s", n_removed, suffix)

    # ``np.asarray`` makes boolean indexing work for plain Python sequences.
    removed_ids = np.asarray(obj_ids)[removed].tolist()
    logger.debug("Removed object_ids: %s", removed_ids)