Source code for wf_psf.data.data_adapter

"""
Data Adapter.

This module manages dataset lifecycle transitions for the WF-PSF pipeline.

Overview
--------
Two orthogonal state machines are maintained:

**1. Structure state**

- ``COMPLETE`` → ``SPLIT`` via :meth:`split_data`
- ``SPLIT`` → ``COMPLETE`` via :meth:`join_data`

**2. Representation state**

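- ``NUMPY`` → ``TENSORFLOW`` via :meth:`convert_to_tensorflow`
- ``TENSORFLOW`` → ``NUMPY`` via :meth:`release_tensorflow` (drops the
  TensorFlow datasets; it does not convert them back)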


Glossary
--------
**COMPLETE**
    Dataset stored as a single container.

**SPLIT**
    Dataset stored as train/test subsets.

**NUMPY**
    Data stored as NumPy arrays.

**TENSORFLOW**
    Data stored as TensorFlow tensors.


Design principles
-----------------
- Structure and representation are orthogonal.
- All transitions are explicit and idempotent where possible.
- No training or model logic lives in this module.
- Dataset field names are canonicalized for downstream models.

Notes
-----
The :class:`DataAdapter` class manages these transitions while providing
a consistent interface for accessing dataset contents.

Authors: Jennifer Pollack <jennifer.pollack@cea.fr>
"""

from enum import Enum, auto
from re import S
import numpy as np
from typing import Any, Optional
from wf_psf.data.constants import (
    CANONICAL_DATASET_KEYS,
    DEFAULT_SEED,
    DEFAULT_TRAIN_FRACTION,
    DATASET_INDEX_KEY,
)
from wf_psf.data.data_utils import DatasetContainer, to_container
from wf_psf.data.tensorflow_converter import TensorFlowDatasetConverter
import logging

logger = logging.getLogger(__name__)


class StructureState(Enum):
    """Structural state of the dataset.

    - COMPLETE: the dataset is complete and not split into train/test.
    - SPLIT: the dataset is split into train and test sets.
    """

    # Values are assigned by ``auto()`` in declaration order.
    COMPLETE = auto()
    SPLIT = auto()
class RepresentationState(Enum):
    """Representation state of the dataset.

    - NUMPY: the dataset is represented as a NumPy array.
    - TENSORFLOW: the dataset is represented as a TensorFlow tensor.
    """

    # Values are assigned by ``auto()`` in declaration order.
    NUMPY = auto()
    TENSORFLOW = auto()
class LoadedDataset:
    """Structured container for loaded dataset.

    Holds either one complete dataset or a train/test pair. The class does
    not enforce that only one of the two layouts is populated; use
    :meth:`is_complete` and :meth:`is_split` to inspect what is present.

    Attributes
    ----------
    complete : dict, optional
        The complete dataset (if in COMPLETE state).
    train : dict, optional
        The training dataset (if in SPLIT state).
    test : dict, optional
        The test dataset (if in SPLIT state).
    """

    def __init__(
        self,
        complete: Optional[dict] = None,
        train: Optional[dict] = None,
        test: Optional[dict] = None,
    ):
        """Initialize the LoadedDataset with either complete or split data.

        Parameters
        ----------
        complete : dict, optional
            The complete dataset.
        train : dict, optional
            The training subset.
        test : dict, optional
            The test subset.
        """
        self.complete = complete
        self.train = train
        self.test = test

    def is_complete(self) -> bool:
        """Check if the dataset is in COMPLETE state.

        Returns
        -------
        bool
            True when ``complete`` is not None (an empty dict still counts).
        """
        return self.complete is not None

    def is_split(self) -> bool:
        """Check if the dataset is in SPLIT state.

        Returns
        -------
        bool
            True only when both ``train`` and ``test`` are not None.
        """
        return self.train is not None and self.test is not None
class DataAdapter:
    """Adapter for managing dataset structure and backend representation.

    The adapter provides a consistent interface to datasets regardless of
    whether they are stored as a complete dataset or as train/test splits,
    and whether the underlying representation is NumPy or TensorFlow.

    It also canonicalizes dataset fields to the names expected by downstream
    models.

    Notes
    -----
    Instances should be created via `DataAdapterFactory.build()`.
    """

    def __init__(
        self,
        dataset: "LoadedDataset",
        converter: Optional["TensorFlowDatasetConverter"],
        params: Optional[Any] = None,
        metadata: Optional[dict] = None,
    ):
        """Initialize the DataAdapter with a loaded dataset and optional converter and parameters.

        Parameters
        ----------
        dataset : LoadedDataset
            The loaded dataset containing either complete or data split into
            train and test (validation) datasets. If both layouts are
            populated, the complete dataset takes precedence and the
            train/test entries are ignored.
        converter : TensorFlowDatasetConverter
            A TensorFlowDatasetConverter object for transforming from NumPy
            to TensorFlow representations. May be None, in which case
            :meth:`convert_to_tensorflow` raises.
        params : optional
            Additional parameters for dataset management. The attributes
            ``canonical_keys``, ``train_fraction`` and ``seed`` are read when
            present; module defaults are used otherwise.
        metadata : optional
            Additional ancillary information for dataset management.

        Raises
        ------
        ValueError
            If ``dataset`` is None, or if the loaded dataset is empty or
            malformed (neither complete nor split).
        """
        if dataset is None:
            raise ValueError("LoadedDataset cannot be None.")

        self._params = params
        self._metadata = metadata
        self._converter = converter
        self._canonical_keys = getattr(params, "canonical_keys", CANONICAL_DATASET_KEYS)
        self._train_fraction = getattr(params, "train_fraction", DEFAULT_TRAIN_FRACTION)
        self._seed = getattr(params, "seed", DEFAULT_SEED)

        # NumPy-side containers; populated below according to structure state.
        self._train_data = None
        self._test_data = None
        self._complete_data = None

        # Determine structure state (COMPLETE wins when both are present).
        if dataset.is_complete():
            self._structure_state = StructureState.COMPLETE
        elif dataset.is_split():
            self._structure_state = StructureState.SPLIT
        else:
            raise ValueError("Loaded dataset is empty or malformed")

        # Convert to containers, then map user-provided keys to canonical keys.
        self._initialize_structure(dataset)
        self._canonicalize_initial_data()

        # Representation state always starts as NUMPY.
        self._representation_state = RepresentationState.NUMPY

        # Placeholders for TensorFlow datasets.
        self._train_tf = None
        self._test_tf = None
        self._complete_tf = None

    def _initialize_structure(self, dataset):
        """Initialize internal data containers based on structure state.

        Shallow-copies the input mappings (``dict.copy``) and wraps them with
        ``to_container``, handling either COMPLETE or SPLIT structure modes.
        """
        if self._structure_state == StructureState.COMPLETE:
            self._complete_data = to_container(dataset.complete.copy())
            return

        self._train_data = to_container(dataset.train.copy())
        self._test_data = to_container(dataset.test.copy())

    def _resolve_target_field(self, split=None):
        """Resolve the target field name from adapter parameters.

        Resolution order:

        1. Split-specific field (if `split` provided)
        2. Global `target_field`
        3. `complete.target_field`
        4. Default to "sources"

        A level that is absent, or whose ``target_field`` is missing or None,
        is skipped so that the next level is tried.

        Parameters
        ----------
        split : str, optional
            Name of the params attribute holding split-specific settings
            (``"train"`` or ``"test"`` when called from this class).

        Returns
        -------
        str
            Name of the field to use as the source/target data.
        """
        params = self._params
        if params is None:
            return "sources"

        # Ordered from most to least specific; missing levels become None.
        levels = []
        if split:
            levels.append(getattr(params, split, None))
        levels.append(params)
        levels.append(getattr(params, "complete", None))

        for level in levels:
            # getattr with a default keeps a partially-filled config from
            # raising AttributeError and lets resolution continue.
            field = getattr(level, "target_field", None)
            if field is not None:
                return field

        return "sources"

    def _canonicalize_initial_data(self):
        """Canonicalize all available dataset splits.

        Applies field normalization to train, test, and complete containers
        using the resolved target field for each split.
        """
        splits = (
            ("train", self._train_data),
            ("test", self._test_data),
            (None, self._complete_data),
        )
        for split, container in splits:
            if container is None:
                continue
            self._canonicalize_container(container, self._resolve_target_field(split))

    def _canonicalize_container(self, container, target_field):
        """Canonicalize dataset fields in-place.

        Maps dataset-specific keys to canonical naming convention.

        Parameters
        ----------
        container : mapping
            Container to normalize; it is mutated and also returned.
        target_field : str
            Field to expose under the canonical name ``"sources"``.

        Returns
        -------
        mapping
            The same ``container`` object.

        Notes
        -----
        - A key that differs from a canonical key only by letter case is
          renamed to the canonical spelling.
        - If ``target_field`` is absent, a warning is logged and the
          container is left without changes for that field.
        """
        # Rename keys (case-insensitive mapping).
        for key in self._canonical_keys:
            if key in container:
                continue  # already canonical

            wanted = key.lower()
            # Snapshot the keys: the container is mutated inside the loop.
            for legacy_key in list(container.keys()):
                if isinstance(legacy_key, str) and legacy_key.lower() == wanted:
                    container[key] = container.pop(legacy_key)
                    break

        # Map target_field to sources.
        if target_field in container:
            container["sources"] = container.pop(target_field)
        else:
            logger.warning(
                f"Target field '{target_field}' not found. "
                f"Available fields: {list(container.keys())}"
            )

        return container

    @property
    def structure_state(self):
        """Return the current structural state of the dataset."""
        return self._structure_state

    @property
    def representation_state(self):
        """Return the current representation state of the dataset."""
        return self._representation_state

    def _active(self, tf_data, numpy_data):
        """Pick the TensorFlow object when in TENSORFLOW state and available.

        The check is ``is not None`` rather than truthiness so the result
        does not depend on how the converted object defines ``bool()``.
        """
        if (
            self._representation_state == RepresentationState.TENSORFLOW
            and tf_data is not None
        ):
            return tf_data
        return numpy_data

    @property
    def complete_data(self):
        """Return the complete dataset in the current representation.

        Falls back to the NumPy container when no TensorFlow version exists.
        """
        return self._active(self._complete_tf, self._complete_data)

    @property
    def train_data(self):
        """Return the training set in the current representation.

        Falls back to the NumPy container when no TensorFlow version exists.
        """
        return self._active(self._train_tf, self._train_data)

    @property
    def test_data(self):
        """Return the test set in the current representation.

        Falls back to the NumPy container when no TensorFlow version exists.
        """
        return self._active(self._test_tf, self._test_data)

    # Read access to data params and metadata
    @property
    def params(self) -> Optional[Any]:
        """Get dataset params."""
        return self._params

    @property
    def metadata(self) -> Optional[dict]:
        """Get dataset metadata."""
        return self._metadata

    def _complete_field(self, name):
        """Return field ``name`` of the NumPy complete dataset, or None."""
        if self._complete_data is None:
            return None
        return self._complete_data.get(name, None)

    # Convenient access to canonical fields for downstream models
    @property
    def sources(self):
        """Get sources for the complete dataset."""
        return self._complete_field("sources")

    @property
    def positions(self):
        """Get positions for the complete dataset."""
        return self._complete_field("positions")

    @property
    def masks(self):
        """Get masks for the complete dataset."""
        return self._complete_field("masks")

    @property
    def zernike_prior(self):
        """Get Zernike prior for the complete dataset."""
        return self._complete_field("zernike_prior")

    def split_data(self, ratio: Optional[float] = None, seed: Optional[int] = None):
        """Split the complete dataset into train and test sets.

        Parameters
        ----------
        ratio : float, optional
            The fraction of the dataset to use for training. Defaults to the
            train fraction resolved at construction time.
        seed : int, optional
            The random seed for reproducibility. Defaults to the seed
            resolved at construction time.

        Raises
        ------
        RuntimeError
            If the dataset is not in COMPLETE state when attempting to split
            (this includes calling it while already in SPLIT state), or if
            only one of the train/test containers exists.
        ValueError
            If a new split is computed and ``ratio`` is outside [0, 1] or the
            dataset length cannot be determined.

        Notes
        -----
        - If train and test containers already exist while the adapter is in
          COMPLETE state (for example after :meth:`join_data`, which keeps
          them), they are reused as-is and ``ratio``/``seed`` are ignored.
        """
        if self._structure_state != StructureState.COMPLETE:
            raise RuntimeError("Split only allowed from COMPLETE state.")

        # Consistency enforcement: train and test must exist together.
        if (self._train_data is None) != (self._test_data is None):
            raise RuntimeError("Inconsistent split state: train/test mismatch.")

        if self._train_data is None and self._test_data is None:
            ratio = ratio if ratio is not None else self._train_fraction
            logger.info(f"Setting train_fraction to {ratio}...")

            seed = seed if seed is not None else self._seed
            logger.info(f"Setting random seed for train-test split to {seed}....")

            self._train_data, self._test_data = self._split(
                self._complete_data, ratio=ratio, seed=seed
            )
        else:
            logger.info(
                "Split data already exists; reusing existing train/test split (idempotent call)."
            )

        self._structure_state = StructureState.SPLIT

    def join_data(self, keys: Optional[list[str]] = None):
        """Join train and test splits into a single complete dataset.

        Concatenates corresponding arrays from the train and test containers
        along the first axis (sample dimension) for the specified keys.

        Parameters
        ----------
        keys : list of str, optional
            Dataset fields to join. If None (or empty), uses the canonical
            dataset keys.

        Raises
        ------
        RuntimeError
            If the adapter is not in SPLIT state or if train/test data is
            missing.

        Notes
        -----
        - Only keys present in both train and test datasets are joined; the
          resulting complete dataset is a plain dict containing only those
          keys.
        - The train and test containers are kept after joining.
        """
        if self._structure_state != StructureState.SPLIT:
            raise RuntimeError("Join only allowed from SPLIT state.")

        if self._train_data is None or self._test_data is None:
            raise RuntimeError("Train or test data is missing; cannot join.")

        # Join on specified keys or the canonical keys by default.
        keys_to_join = keys or self._canonical_keys

        self._complete_data = {
            k: np.concatenate([self._train_data[k], self._test_data[k]], axis=0)
            for k in keys_to_join
            if k in self._train_data and k in self._test_data
        }

        self._structure_state = StructureState.COMPLETE

    def convert_to_tensorflow(self, simPSF, n_bins_lambda, mode):
        """Convert dataset containers from NumPy to TensorFlow representation.

        Conversion is performed on the active structure:

        - SPLIT: converts train and test datasets separately
        - COMPLETE: converts the full dataset

        Parameters
        ----------
        simPSF : PSFSimulator
            Simulator instance passed to the converter.
        n_bins_lambda : int
            Number of wavelength bins used during conversion.
        mode : DatasetMode
            Dataset operation mode used to select the appropriate dataset
            schema for a given pipeline process (e.g. training, validation,
            inference). Its ``name`` attribute is used in log messages.

        Raises
        ------
        RuntimeError
            If no converter is configured.

        Notes
        -----
        - Conversion is idempotent: if the data is already in TensorFlow
          representation, this method does nothing.
        - Converted datasets are stored in internal attributes
          (``_train_tf``, ``_test_tf``, ``_complete_tf``) and do not
          overwrite the original NumPy data.
        - Sample counts reported in the log are the length of the
          ``DATASET_INDEX_KEY`` entry of each container.
        """
        if self._representation_state == RepresentationState.TENSORFLOW:
            return

        if self._converter is None:
            raise RuntimeError("No converter provided.")

        if self._structure_state == StructureState.SPLIT:
            n_train = len(self._train_data[DATASET_INDEX_KEY])
            n_test = len(self._test_data[DATASET_INDEX_KEY])

            logger.info(
                f"Converting training dataset to TensorFlow "
                f"(mode={mode.name.lower()}, samples={n_train})..."
            )
            self._train_tf = self._converter.convert_dataset(
                self._train_data, simPSF, n_bins_lambda, mode=mode
            )

            logger.info(
                f"Converting test dataset to TensorFlow "
                f"(mode={mode.name.lower()}, samples={n_test})..."
            )
            self._test_tf = self._converter.convert_dataset(
                self._test_data, simPSF, n_bins_lambda, mode=mode
            )
        else:
            n_complete = len(self._complete_data[DATASET_INDEX_KEY])

            logger.info(
                f"Converting complete dataset to TensorFlow "
                f"(mode={mode.name.lower()}, samples={n_complete})..."
            )
            self._complete_tf = self._converter.convert_dataset(
                self._complete_data, simPSF, n_bins_lambda, mode=mode
            )

        self._representation_state = RepresentationState.TENSORFLOW
        logger.info("Dataset representation state updated to TensorFlow.")

    def _split(self, data, ratio: Optional[float] = None, seed: Optional[int] = None):
        """Split a dataset container into train and test subsets.

        The split is performed along the first dimension (sample axis) using
        a random permutation. Only NumPy arrays whose leading dimension
        matches the number of samples are split; every other entry is placed
        in both outputs as the same object (no copy is made).

        Parameters
        ----------
        data : DatasetContainer
            Container holding the complete dataset.
        ratio : float, optional
            Fraction of samples assigned to the training set, within [0, 1].
            If None, defaults to ``self._params.train_fraction`` or 0.8.
        seed : int, optional
            Random seed used to generate the split.

        Returns
        -------
        train_data : dict
            Dictionary containing the training subset.
        test_data : dict
            Dictionary containing the test subset.

        Raises
        ------
        ValueError
            If ``ratio`` is outside [0, 1], or if the dataset size cannot be
            read from the ``DATASET_INDEX_KEY`` entry of ``data``.

        Notes
        -----
        - The dataset size is the leading dimension of
          ``data[DATASET_INDEX_KEY]``.
        - The training set receives ``int(n * ratio)`` samples.
        """
        # ``is None`` (not truthiness) so an explicit ratio of 0.0 is honoured.
        if ratio is None:
            ratio = getattr(self._params, "train_fraction", 0.8)
        if not 0.0 <= ratio <= 1.0:
            raise ValueError(f"Train fraction must be within [0, 1]; got {ratio}.")

        rng = np.random.default_rng(seed)

        # Determine sample size from the index dataset key.
        try:
            n = data[DATASET_INDEX_KEY].shape[0]
        except (KeyError, AttributeError, IndexError, TypeError) as err:
            raise ValueError(
                f"Could not determine dataset length from index key "
                f"'{DATASET_INDEX_KEY}'"
            ) from err

        n_train = int(n * ratio)
        indices = rng.permutation(n)
        train_idx = indices[:n_train]
        test_idx = indices[n_train:]

        train_data = {}
        test_data = {}
        for k, v in data.items():
            # ndim guard: 0-d arrays have no leading axis to index.
            if isinstance(v, np.ndarray) and v.ndim > 0 and v.shape[0] == n:
                train_data[k] = v[train_idx]
                test_data[k] = v[test_idx]
            else:
                # Entries without a matching leading axis are shared untouched.
                train_data[k] = v
                test_data[k] = v

        return train_data, test_data

    def release_numpy(self):
        """Release NumPy datasets.

        Replaces the complete, train and test NumPy containers with empty
        ``DatasetContainer`` instances.

        Raises
        ------
        RuntimeError
            If the adapter is not in TENSORFLOW representation state.
        """
        if self._representation_state != RepresentationState.TENSORFLOW:
            raise RuntimeError("NumPy can only be released after TF conversion.")

        self._complete_data = DatasetContainer({})
        self._train_data = DatasetContainer({})
        self._test_data = DatasetContainer({})

    def release_tensorflow(self):
        """Release tensorflow datasets.

        Drops the stored TensorFlow datasets and sets the representation
        state back to NUMPY. NumPy containers are not modified, so data
        previously dropped by :meth:`release_numpy` is not restored.
        """
        self._complete_tf = None
        self._train_tf = None
        self._test_tf = None
        self._representation_state = RepresentationState.NUMPY