wf_psf.data.data_handler

Data Handler Module.

Provides tools for loading, preprocessing, and managing data used in both training and inference workflows.

Includes:

  • The DataHandler class for managing datasets and associated metadata

  • Utility functions for loading structured data products

  • Preprocessing routines for spectral energy distributions (SEDs), including format conversion (e.g., to TensorFlow) and transformations

This module serves as a central interface between raw data and modeling components.

Authors: Jennifer Pollack <jennifer.pollack@cea.fr>, Tobias Liaudat <tobiasliaudat@gmail.com>

Functions

extract_star_data(data, train_key, test_key) – Extract and concatenate star-related data from training and test datasets.

get_data_array(data, run_type[, key, ...]) – Retrieve data from a dataset depending on the run type.

Classes

DataHandler(dataset_type, data_params, ...) – DataHandler for WaveDiff PSF modeling.

class wf_psf.data.data_handler.DataHandler(dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True, dataset: dict | list | None = None, sed_data: dict | list | None = None)[source]

Bases: object

DataHandler for WaveDiff PSF modeling.

This class manages loading, preprocessing, and TensorFlow conversion of datasets used for PSF model training, testing, and inference in the WaveDiff framework.

Parameters:
  • dataset_type (str) – Indicates the dataset mode (“train”, “test”, or “inference”).

  • data_params (RecursiveNamespace) – Configuration object containing dataset parameters (e.g., file paths, preprocessing flags).

  • simPSF (PSFSimulator) – An instance of the PSFSimulator class used to encode SEDs into a TensorFlow-compatible format.

  • n_bins_lambda (int) – Number of wavelength bins used to discretize SEDs.

  • load_data (bool, optional) – If True (default), loads and processes data during initialization. If False, data loading must be triggered explicitly.

  • dataset (dict or list, optional) – If provided, uses this pre-loaded dataset instead of triggering automatic loading.

  • sed_data (dict or list, optional) – If provided, uses this SED data directly instead of extracting it from the dataset.
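
As a sketch, a pre-built dataset passed through the `dataset` argument might look like the following. This is a NumPy stand-in with hypothetical field names (taken from the keys mentioned in this page); the exact contents depend on the data configuration, and `data_params` and `simPSF` in the commented call are assumed to come from the usual WaveDiff configuration objects.

```python
import numpy as np

# Hypothetical pre-loaded dataset illustrating the kind of structure that
# can be passed via the `dataset` argument (field names are examples).
n_stars, stamp_size, n_bins_lambda = 4, 32, 8
dataset = {
    "positions": np.random.rand(n_stars, 2),            # focal-plane positions
    "noisy_stars": np.random.rand(n_stars, stamp_size, stamp_size),
    "SEDs": np.random.rand(n_stars, n_bins_lambda, 2),  # per-bin SED samples
}

# With load_data=False, loading and SED processing are triggered explicitly:
# handler = DataHandler("train", data_params, simPSF, n_bins_lambda,
#                       load_data=False, dataset=dataset)
# handler.process_sed_data(dataset["SEDs"])
```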

dataset_type (str) – Indicates the dataset mode (“train”, “test”, or “inference”).

data_params (RecursiveNamespace) – Configuration parameters for data access and structure.

simPSF (PSFSimulator) – Simulator used to transform SEDs into TensorFlow-ready tensors.

n_bins_lambda (int) – Number of wavelength bins in the SED representation.

load_data_on_init (bool) – Whether data was loaded automatically during initialization.

dataset (dict) – Loaded dataset including keys such as ‘positions’, ‘stars’, or ‘noisy_stars’.

sed_data (tf.Tensor) – TensorFlow-formatted SED data with shape [batch_size, n_bins_lambda, features].

Attributes:

tf_positions – Get positions as a TensorFlow tensor.

Methods

load_dataset() – Load the dataset based on the specified dataset type.

process_sed_data(sed_data) – Generate and process SED (Spectral Energy Distribution) data.

validate_and_process_dataset() – Validate the dataset structure and convert fields to TensorFlow tensors.

load_dataset()[source]

Load the dataset based on the specified dataset type.

process_sed_data(sed_data)[source]

Generate and process SED (Spectral Energy Distribution) data.

This method transforms raw SED inputs into TensorFlow tensors suitable for model input. It generates wavelength-binned SED elements using the PSF simulator, converts the result into a tensor, and transposes it to match the expected shape for training or inference.

Parameters:

sed_data (list or array-like) – A list or array of raw SEDs, where each SED is typically a vector of flux values or coefficients. These will be processed using the PSF simulator.

Raises:

ValueError – If sed_data is None.

Notes

The resulting tensor is stored in self.sed_data and has shape (num_samples, n_bins_lambda, n_components), where:

  • num_samples is the number of SEDs,

  • n_bins_lambda is the number of wavelength bins,

  • n_components is the number of components per SED (e.g., filters or basis terms).

The intermediate tensor is created with tf.float64 for precision during generation, but is converted to tf.float32 after processing for use in training.
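
The shape and dtype contract described in these notes can be sketched with a NumPy stand-in (random values for illustration; the real method obtains its per-SED elements from the PSF simulator and uses TensorFlow ops):

```python
import numpy as np

# NumPy stand-in for the process_sed_data shape/dtype contract.
num_samples, n_bins_lambda, n_components = 3, 8, 2

# One (n_bins_lambda, n_components) element per SED, generated in float64
# for precision, then stacked along the sample axis:
raw = np.random.rand(num_samples, n_bins_lambda, n_components).astype(np.float64)

# Cast down to float32 after processing, as noted above:
sed_data = raw.astype(np.float32)
print(sed_data.shape, sed_data.dtype)  # (3, 8, 2) float32
```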

property tf_positions

Get positions as a TensorFlow tensor.

validate_and_process_dataset()[source]

Validate the dataset structure and convert fields to TensorFlow tensors.
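
The validate-and-convert step can be sketched as follows. This is a hypothetical standalone stand-in: the required key name and the `validate_and_convert` helper are illustrative, and NumPy conversion stands in for the TensorFlow tensor conversion the real method performs.

```python
import numpy as np

# Hypothetical stand-in for validating dataset structure and converting
# fields (NumPy arrays here; the real method produces TensorFlow tensors).
def validate_and_convert(dataset, required=("positions",)):
    for k in required:
        if k not in dataset:
            raise KeyError(f"dataset is missing required field '{k}'")
    return {k: np.asarray(v) for k, v in dataset.items()}

ds = validate_and_convert({"positions": [[0.0, 1.0], [2.0, 3.0]]})
print(ds["positions"].shape)  # (2, 2)
```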

wf_psf.data.data_handler.extract_star_data(data, train_key: str, test_key: str) → ndarray[source]

Extract and concatenate star-related data from training and test datasets.

This function retrieves arrays (e.g., postage stamps, masks, positions) from both the training and test datasets using the specified keys, converts them to NumPy if necessary, and concatenates them along the first axis.

Parameters:
  • data (DataConfigHandler) – Object containing training and test datasets.

  • train_key (str) – Key to retrieve data from the training dataset (e.g., ‘noisy_stars’, ‘masks’).

  • test_key (str) – Key to retrieve data from the test dataset (e.g., ‘stars’, ‘masks’).

Returns:

Concatenated NumPy array containing the selected data from both training and test sets.

Return type:

np.ndarray

Raises:

KeyError – If either the training or test dataset does not contain the requested key.

Notes

  • Designed for datasets with separate train/test splits, such as when evaluating metrics on held-out data.

  • TensorFlow tensors are automatically converted to NumPy arrays.

  • Requires eager execution if TensorFlow tensors are present.
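
The concatenation behavior can be illustrated with a minimal stand-in (the dictionary datasets and the `_extract` helper are hypothetical; the real function reads its arrays from a DataConfigHandler):

```python
import numpy as np

# Minimal stand-in mimicking extract_star_data's concatenation semantics.
train_ds = {"noisy_stars": np.zeros((10, 32, 32))}
test_ds = {"stars": np.ones((4, 32, 32))}

def _extract(train_ds, test_ds, train_key, test_key):
    a = np.asarray(train_ds[train_key])    # TF tensors would use .numpy()
    b = np.asarray(test_ds[test_key])
    return np.concatenate([a, b], axis=0)  # stack along the first axis

stars = _extract(train_ds, test_ds, "noisy_stars", "stars")
print(stars.shape)  # (14, 32, 32)
```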

wf_psf.data.data_handler.get_data_array(data, run_type: str, key: str | None = None, train_key: str | None = None, test_key: str | None = None, allow_missing: bool = True) → ndarray | None[source]

Retrieve data from a dataset depending on the run type.

This function provides a unified interface for accessing data across different execution contexts (training, simulation, metrics, inference). It handles key resolution with sensible fallbacks and optional missing data tolerance.

Parameters:
  • data (DataConfigHandler) – Dataset object containing training, test, or inference data. Expected to have methods compatible with the specified run_type.

  • run_type ({"training", "simulation", "metrics", "inference"}) –

    Execution context that determines how data is retrieved:

    • “training”, “simulation”, “metrics”: uses the extract_star_data function

    • “inference”: retrieves data directly from the dataset using key lookup

  • key (str, optional) – Primary key for data lookup. Used directly for inference run_type. If None, falls back to train_key value. Default is None.

  • train_key (str, optional) – Key for training dataset access. If None and key is provided, defaults to key value. Default is None.

  • test_key (str, optional) – Key for test dataset access. If None, defaults to the resolved train_key value. Default is None.

  • allow_missing (bool, default True) –

    Control behavior when data is missing or keys are not found:

    • True: Return None instead of raising exceptions

    • False: Raise appropriate exceptions (KeyError, ValueError)

Returns:

Retrieved data as NumPy array. Returns None only when allow_missing=True and the requested data is not available.

Return type:

np.ndarray or None

Raises:
  • ValueError – If run_type is not one of the supported values, or if no key can be resolved for the operation and allow_missing=False.

  • KeyError – If the specified key is not found in the dataset and allow_missing=False.

Notes

Key resolution follows this priority order:

  1. train_key = train_key or key

  2. test_key = test_key or resolved_train_key

  3. key = key or resolved_train_key (for inference fallback)

For TensorFlow tensors, the .numpy() method is called to convert to NumPy. Other data types are converted using np.asarray().
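
The priority order above can be sketched as a standalone reimplementation (illustrative only, not the library source):

```python
# Sketch of the key-resolution priority described in the Notes.
def resolve_keys(key=None, train_key=None, test_key=None):
    train_key = train_key or key        # 1. train_key falls back to key
    test_key = test_key or train_key    # 2. test_key falls back to train_key
    key = key or train_key              # 3. key falls back (inference case)
    return key, train_key, test_key

print(resolve_keys(key="positions"))
# ('positions', 'positions', 'positions')
```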

Examples

>>> # Training data retrieval
>>> train_data = get_data_array(data, "training", train_key="noisy_stars")
>>> # Inference with fallback handling
>>> inference_data = get_data_array(data, "inference", key="positions",
...                                  allow_missing=True)
>>> if inference_data is None:
...     print("No inference data available")
>>> # Using key parameter for both train and inference
>>> result = get_data_array(data, "inference", key="positions")